write scan in uops (#13321)

* write scan in uops

* ops range

* no need for variable

* meh, later

* shorter
This commit is contained in:
George Hotz
2025-11-17 16:58:08 -08:00
committed by GitHub
parent 8894a5409d
commit e4fead8a86
4 changed files with 42 additions and 9 deletions

View File

@@ -50,8 +50,36 @@ class TestOuterRange(unittest.TestCase):
# 3 matmuls with outer world range
i = UOp.range(3, -100, AxisType.OUTER)
vec_i = Tensor(vec.uop.after(i))
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
out = Tensor(vec.uop.after(vec_i.uop.store((vec_i.contiguous() @ mats[vi]).uop).end(i)))
comp = vec_i.contiguous() @ mats[i]
store = vec_i.uop.store(comp.uop).end(i)
out = Tensor(vec.uop.after(store))
out.realize()
# TODO: testing allclose
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
class TestOuterScan(unittest.TestCase):
def _test_scan(self):
vec = Tensor.randn(1, 10).realize()
mats = Tensor.randn(3, 10, 10).realize()
# 3 matmuls in "scan"
vec1 = vec @ mats[0]
vec2 = vec1 @ mats[1]
vec3 = vec2 @ mats[2]
ref = Tensor.stack(vec1, vec2, vec3)
ref.realize()
return vec, mats, ref
def test_uop_scan_matmul(self):
vec, mats, ref = self._test_scan()
# 3 matmuls with SCAN
i = UOp.range(3, -100, AxisType.OUTER)
out = Tensor.empty(3, 1, 10)
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
store = out[i].uop.store(comp.uop).end(i)
out = Tensor(out.uop.after(store))
out.realize()
# TODO: testing allclose

View File

@@ -23,10 +23,12 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
in_degree: dict[UOp, int] = {}
var_vals: dict[str, int] = {}
for u in sched_sink.toposort():
if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
if u.op is Ops.RANGE:
in_degree.setdefault(u, 0)
continue
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
k = u.src[1]
in_degree.setdefault(k, 0)
if k.op is Ops.RANGE: continue
for s in k.src[0].src if k.op is Ops.END else k.src:
if s.op is Ops.AFTER:
children[s.src[1]].append(k)

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap, disable_gc
@@ -397,6 +397,9 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
if r.tag != (): return None
if r.arg[-1] == AxisType.OUTER:
# for outer range, we replace with a bound variable
return UOp.variable("range_"+range_str(r), r.vmin, r.vmax).bind(r.replace(tag=None))
ret = r.replace(arg=(ctx.range,)+r.arg[1:], tag=None)
ctx.range += 1
return ret
@@ -571,12 +574,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
# TODO: we can probably get this earlier
sink_tags = [s.tag for s in tsink.src]
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
becomes_map: dict[UOp, UOp] = {}
for tag, s in zip(sink_tags, tsink.src):
assert tag is not None

View File

@@ -42,7 +42,7 @@ shared_spec = PatternMatcher([
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
# RANGE/SPECIAL define loops, END closes them
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
])
# ***** UOp spec in the Tensor graph *****
@@ -171,7 +171,7 @@ kernel_spec = PatternMatcher([
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# END can end multiple axes here
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
# bufferize can be on anything
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),