range has only one src now [pr] (#10100)

* range has only one src now

* fix z3 checker

* ci fix

* needs shell

* try pip ensure update

* that ensurepip is useless

* upgrade pip before cache

* windows happy?
This commit is contained in:
George Hotz
2025-04-29 10:31:05 -04:00
committed by GitHub
parent 427471550a
commit c3ff308abb
14 changed files with 34 additions and 32 deletions

View File

@@ -44,6 +44,9 @@ runs:
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Upgrade pip
shell: bash
run: python -m pip install --upgrade pip
# **** Caching packages ****
# TODO: the cache key should include inputs.deps, but it can't because cache keys can't contain commas

View File

@@ -558,11 +558,10 @@ class TestUOpGraph(unittest.TestCase):
def test_switched_range_order(self):
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
c0 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 2)
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 0)
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), 1)
r1 = UOp(Ops.RANGE, dtypes.int, (c2,), 0)
r2 = UOp(Ops.RANGE, dtypes.int, (c2,), 1)
alu = UOp(Ops.MUL, dtypes.int, (r2, r1))
store = UOp(Ops.STORE, dtypes.void, (glbl.index(alu), cf))
uops = to_uops_list([store])

View File

@@ -64,21 +64,21 @@ class TestFoldingAndReduction(unittest.TestCase):
def test_full_graph_rewrite_reduction_with_unused_range(self):
const1 = UOp.const(dtypes.int32, 15)
const2 = UOp.const(dtypes.int32, 25)
rng = UOp.range(dtypes.int32, 0, 10, idx=0)
rng = UOp.range(dtypes.int32, 10, idx=0)
optimized_sink = apply_rewrite((const1 + const2).reduce(Ops.ADD, rng))
expected_sum = 10 * (15 + 25)
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("currently failing")
def test_full_graph_rewrite_range_reduction(self):
simple_range = UOp.range(dtypes.int32, 0, 5, idx=0)
simple_range = UOp.range(dtypes.int32, 5, idx=0)
optimized_sink = apply_rewrite(simple_range.reduce(Ops.ADD, simple_range))
expected_sum = sum(range(5))
self.assertEqual(optimized_sink.arg, expected_sum)
@unittest.skip("currently failing")
def test_full_graph_rewrite_simple_reduction_folding(self):
simple_range = UOp.range(dtypes.int32, 0, 4, idx=0)
simple_range = UOp.range(dtypes.int32, 4, idx=0)
add_uop = simple_range + UOp.const(dtypes.int32, 1)
optimized_sink = apply_rewrite(add_uop.reduce(Ops.ADD, simple_range))
expected_sum = sum(i + 1 for i in range(4))
@@ -86,8 +86,8 @@ class TestFoldingAndReduction(unittest.TestCase):
@unittest.skip("currently failing")
def test_full_graph_rewrite_nested_loop_collapse(self):
outer_range = UOp.range(dtypes.int32, 0, 8, 0)
inner_range = UOp.range(dtypes.int32, 0, 4, 1)
outer_range = UOp.range(dtypes.int32, 8, 0)
inner_range = UOp.range(dtypes.int32, 4, 1)
expr = (outer_range * 10) + inner_range
optimized_reduce_uop = apply_rewrite(expr.reduce(Ops.ADD, outer_range, inner_range))
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)

View File

@@ -19,7 +19,7 @@ def get_load_image_uop(image_shape:tuple[int, ...], valid:UOp, idx:tuple[UOp, UO
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=n, src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=n, src=(UOp.const(dtypes.int, nmax),))
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
@@ -39,7 +39,7 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(f2.is_increasing())
self.assertTrue(f3.is_increasing())
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=0, src=()), UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
self.assertTrue(rng.is_increasing())
self.assertTrue((rng+2).is_increasing())

View File

@@ -352,7 +352,7 @@ def reduce_rangeless(red:UOp):
ret = red.src[0]
if red.arg is Ops.ADD:
for r in red.src[1:]:
ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
@@ -367,9 +367,9 @@ pm_reduce_collapse = PatternMatcher([
((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
# fold the range
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat(Ops.CONST, arg=0), UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: (r.src[1]-cut).maximum(0).minimum(r.src[1]-r.src[0]).cast(val.dtype) * val),
lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
lambda r,cut,val: (cut-r.src[0]).maximum(0).minimum(r.src[1]-r.src[0]).cast(val.dtype) * val),
lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
# devectorize REDUCE
(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
# REDUCE on ADD
@@ -385,7 +385,7 @@ pm_reduce_collapse = PatternMatcher([
lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
# INDEX on RANGE / gated RANGE
(UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= r.src[0]) & (idx.cast(r.dtype) < r.src[1]))),
lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
# index/load. TODO: this is more aggressive than needed
(UPat((Ops.INDEX, Ops.LOAD), name="alu"), no_vectorized_alu),
# AND on WHERE
@@ -416,7 +416,7 @@ def reduce_unparented(red:UOp):
if len(reduce_unparented) == 0: return None
ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
if red.arg is Ops.ADD:
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
pm_reduce = PatternMatcher([

View File

@@ -108,10 +108,10 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i) for i,g in enumerate(full_shape[:first_reduce])]
idxs = [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(g),), i) for i,g in enumerate(full_shape[:first_reduce])]
# reduce loops
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(g)), i)
idxs += [UOp(Ops.RANGE, dtypes.int, (sint_to_uop(g),), i)
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# upcast loops
@@ -122,7 +122,7 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(0), sint_to_uop(full_shape[a])), 1000+a)
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (sint_to_uop(full_shape[a]),), 1000+a)
return IndexContext(idxs, ridxs)

View File

@@ -419,7 +419,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
unmasked_st = ShapeTracker.from_shape(()).reshape((1,)*len(st.shape)).expand(st.shape).to_uop()
return UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(self.replace(src=(unmasked_st,)), UOp.const(self.dtype, 0).replace(src=(unmasked_st,)))
@staticmethod
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
def r(self, op:Ops, axis:tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
@@ -645,7 +645,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.WHERE and dtypes.is_int(self.dtype): return min(self.src[1].vmin, self.src[2].vmin), max(self.src[1].vmax, self.src[2].vmax)
# NOTE: returned UOp is assumed to be CONST
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
if self.op is Ops.RANGE: return 0, (self.src[0]-1).vmax
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.UNROLL, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
# TODO: Ops.SPECIAL is Ops.DEFINE_VAR

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Callable
from typing import Optional, Callable, cast
import functools, math
from enum import Enum, auto
from dataclasses import dataclass, field, replace
@@ -69,7 +69,7 @@ class Estimates:
for u in uops:
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= (u.src[1] - u.src[0]).ssimplify()
mults *= cast(sint, u.src[0].ssimplify())
# SPECIAL are already counted in mults
mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)

View File

@@ -15,7 +15,7 @@ base_rewrite = PatternMatcher([
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = 0; {ctx[x]} < {ctx[x.src[0]]}; {ctx[x]}++) {{"),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),

View File

@@ -98,10 +98,10 @@ base_rewrite = PatternMatcher([
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
f" br label %loop_entry_{x.arg}\nloop_entry_{x.arg}:\n"
f" br label %loop_body_{x.arg}\nloop_body_{x.arg}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
f" {ctx[x]} = phi {ldt(x.dtype)} [0, %loop_entry_{x.arg}], [{ctx[x]}phi, %loop_latch_{x.arg}]"),
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
f" br label %loop_latch_{x.src[0].arg}\nloop_latch_{x.src[0].arg}:\n"
f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[0]]}\n"
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg}, label %loop_exit_{x.src[0].arg}\nloop_exit_{x.src[0].arg}:"),
# if

View File

@@ -112,12 +112,12 @@ string_rewrite = PatternMatcher([
f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]),
(UPat(Ops.DEFINE_ACC, name="x", src=(UPat.cvar("pred"),), allow_any_len=True),
lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"),
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]),
(UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, 0;", "LOOP_" + f"{ctx.r[x][1:]}:"]),
(UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]),
(UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"),
(UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [
ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]),
ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[0]], dtypes.int, ctx.types[dtypes.int]),
f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]),
(UPat(Ops.DEFINE_LOCAL, name="x"),
lambda ctx, x: [f".shared .align 4 .b8 {x.arg}[{x.dtype.size*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg}[0];"]),

View File

@@ -91,11 +91,11 @@ class PythonProgram:
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
ul[i] = inp[0]
elif uop is Ops.RANGE:
if i not in ul: ul[i] = [inp[0][0]] * warp_size
if i not in ul: ul[i] = [0] * warp_size
else:
for j in range(len(ul[i])):
ul[i][j] += 1
if ul[i][0] == inp[1][0]:
if ul[i][0] == inp[0][0]:
del ul[i]
i = loop_ends[i] + 1
continue

View File

@@ -89,7 +89,7 @@ class View:
def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> tuple[UOp, UOp]:
"""(idx, valid)"""
if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
if idxs is None: idxs = [UOp.range(dtypes.int, s, i) for i,s in enumerate(self.shape)]
iexpr = sint_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st

View File

@@ -20,7 +20,7 @@ try:
(UPat(Ops.SPECIAL, src=(), name="x"), lambda x: UOp(Ops.SPECIAL, arg=x.arg[0], src=(x.ufix(x.arg[1]),))),
(UPat(Ops.SPECIAL, src=UPat(Ops.NOOP), name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg, 0, x.src[0].arg-1, ctx[0]))),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(x.arg[0], x.arg[1], x.arg[2], ctx[0]))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", x.src[0].arg, x.src[1].arg-1, ctx[0]))),
(UPat(Ops.RANGE, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"ridx{x.arg}", 0, x.src[0].arg-1, ctx[0]))),
(UPat(Ops.LOAD, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=create_bounded(f"load{ctx[1].setdefault(x, len(ctx[1]))}", x.vmin, x.vmax, ctx[0]))),
(UPat(Ops.CONST, name="x"), lambda x,ctx: UOp(Ops.NOOP, arg=(z3.BoolVal if dtypes.is_bool(x.dtype) else z3.IntVal)(x.arg, ctx=ctx[0].ctx))),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0]),
@@ -112,7 +112,7 @@ spec = PatternMatcher([
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(Ops.DEFINE_VAR, name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.RANGE, src=(UPat.var("x"), UPat.var("y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype and isinstance(rng.arg, int)),
(UPat(Ops.RANGE, src=(UPat.var("x"),), name="rng"), lambda rng,x: rng.dtype == x.dtype and isinstance(rng.arg, int)),
(UPat(Ops.SPECIAL, src=()), lambda: True),
# TODO: confirm the args of both of these are shapetrackers