generic range changes that work for str + int (#13350)

* generic range changes that work for str + int

* opt range counts up
This commit is contained in:
George Hotz
2025-11-19 08:07:49 -08:00
committed by GitHub
parent 1a72ac16a6
commit 225eb1500f
2 changed files with 6 additions and 7 deletions

View File

@@ -18,6 +18,7 @@ class Scheduler:
self.ast, self.ren = ast, ren
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
self.opt_range = itertools.count(start=max([x.arg[0] for x in self.rngs], default=0)+1)
@property
def rngs(self):
@@ -29,8 +30,6 @@ class Scheduler:
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
@property
def axis_types(self): return [x.arg[-1] for x in self.rngs]
@property
def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
def shape_str(self) -> list[str]:
@@ -95,7 +94,7 @@ class Scheduler:
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
if (old_sz:=rng.src[0].divides(amount)) is None:
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
new_rng = UOp.range(amount, next(self.opt_range), new_type) if input_new_rng is None else input_new_rng
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}")
@@ -231,9 +230,9 @@ class Scheduler:
for tc in tensor_cores:
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
# tensor cores have three ranges. X, Y, and REDUCE
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True)
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: x.arg[0], reverse=True)
red_ranges = sorted(reduceop.src[1:], key=lambda x: x.arg[0], reverse=True)
if DEBUG >= 3:
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")

View File

@@ -1350,7 +1350,7 @@ pm_pyrender_extra = PatternMatcher([
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
# NOTE: range has srcs sometimes after control flow
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
"UOp.range("+', '.join([str(c.arg)] + [repr(y) for y in x.arg])+
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
# TODO: index shouldn't mismatch dtype
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x: