mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
generic range changes that work for str + int (#13350)
* generic range changes that work for str + int * opt range counts up
This commit is contained in:
@@ -18,6 +18,7 @@ class Scheduler:
|
||||
self.ast, self.ren = ast, ren
|
||||
self.dont_use_locals = self.ast.arg.dont_use_locals if self.ast.arg is not None else False
|
||||
self.applied_opts = list(self.ast.arg.applied_opts) if self.ast.arg is not None else []
|
||||
self.opt_range = itertools.count(start=max([x.arg[0] for x in self.rngs], default=0)+1)
|
||||
|
||||
@property
|
||||
def rngs(self):
|
||||
@@ -29,8 +30,6 @@ class Scheduler:
|
||||
def full_shape(self): return [ssimplify(x.src[0]) for x in self.rngs]
|
||||
@property
|
||||
def axis_types(self): return [x.arg[-1] for x in self.rngs]
|
||||
@property
|
||||
def maxarg(self): return max([x.arg[0] for x in self.rngs], default=0)
|
||||
|
||||
# strings like ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2']
|
||||
def shape_str(self) -> list[str]:
|
||||
@@ -95,7 +94,7 @@ class Scheduler:
|
||||
def shift_to(self, rng:UOp, amount:int, new_type:AxisType, top:bool=False, input_new_rng=None):
|
||||
if (old_sz:=rng.src[0].divides(amount)) is None:
|
||||
raise KernelOptError(f"{amount} can't divide {rng.src[0]} in {self.colored_shape()}")
|
||||
new_rng = UOp.range(amount, self.maxarg+1, new_type) if input_new_rng is None else input_new_rng
|
||||
new_rng = UOp.range(amount, next(self.opt_range), new_type) if input_new_rng is None else input_new_rng
|
||||
replaced_rng = rng.replace(src=(UOp.const(dtypes.int, old_sz),))
|
||||
sub_axis = (new_rng * old_sz + replaced_rng) if top else (replaced_rng * amount + new_rng)
|
||||
self.ast = self.ast.substitute({rng:sub_axis}, name=f"shift {rng.arg[:-1]} {amount} {str(new_type).split('.')[1].lower()}")
|
||||
@@ -231,9 +230,9 @@ class Scheduler:
|
||||
for tc in tensor_cores:
|
||||
if tc.dtype_in == in0.dtype.scalar() and tc.dtype_in == in1.dtype.scalar() and tc.dtype_out == reduceop.dtype.scalar():
|
||||
# tensor cores have three ranges. X, Y, and REDUCE
|
||||
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: -x.arg[0])
|
||||
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: -x.arg[0])
|
||||
red_ranges = sorted(reduceop.src[1:], key=lambda x: -x.arg[0])
|
||||
in0_ranges = sorted([u for u in in0.ranges if u not in in1.ranges], key=lambda x: x.arg[0], reverse=True)
|
||||
in1_ranges = sorted([u for u in in1.ranges if u not in in0.ranges], key=lambda x: x.arg[0], reverse=True)
|
||||
red_ranges = sorted(reduceop.src[1:], key=lambda x: x.arg[0], reverse=True)
|
||||
if DEBUG >= 3:
|
||||
print(f"TC({axis}): {[(x.arg[0],x.vmax+1) for x in in0_ranges]}",
|
||||
f"{[(x.arg[0],x.vmax+1) for x in in1_ranges]} {[(x.arg[0],x.vmax+1) for x in red_ranges]}")
|
||||
|
||||
@@ -1350,7 +1350,7 @@ pm_pyrender_extra = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="r"), lambda ctx,r: f"{ctx[r.src[0]]}.r({r.arg[0]}, {r.arg[1]})"),
|
||||
# NOTE: range has srcs sometimes after control flow
|
||||
(UPat(Ops.RANGE, src=(UPat(Ops.CONST, name="c"),), allow_any_len=True, name="x"), lambda ctx,x,c:
|
||||
"UOp.range("+', '.join([str(c.arg)] + [str(y) for y in x.arg])+
|
||||
"UOp.range("+', '.join([str(c.arg)] + [repr(y) for y in x.arg])+
|
||||
(f', src={srcs(ctx, x.src[1:])}' if len(x.src) > 1 else '')+(', dtype='+str(x.dtype) if x.dtype is not dtypes.index else '')+")"),
|
||||
# TODO: index shouldn't mismatch dtype
|
||||
(UPat(Ops.INDEX, src=(UPat(), UPat()), allow_any_len=True, name="x"), lambda ctx,x:
|
||||
|
||||
Reference in New Issue
Block a user