mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
write scan in uops (#13321)
* write scan in uops * ops range * no need for variable * meh, later * shorter
This commit is contained in:
@@ -50,8 +50,36 @@ class TestOuterRange(unittest.TestCase):
|
||||
# 3 matmuls with outer world range
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
vec_i = Tensor(vec.uop.after(i))
|
||||
vi = UOp.variable("i", i.vmin, i.vmax).bind(i)
|
||||
out = Tensor(vec.uop.after(vec_i.uop.store((vec_i.contiguous() @ mats[vi]).uop).end(i)))
|
||||
comp = vec_i.contiguous() @ mats[i]
|
||||
store = vec_i.uop.store(comp.uop).end(i)
|
||||
out = Tensor(vec.uop.after(store))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
assert Tensor.allclose(ref, out, atol=1e-6), f"{ref.numpy()=}, {out.numpy()=}"
|
||||
|
||||
class TestOuterScan(unittest.TestCase):
|
||||
def _test_scan(self):
  """Create random inputs and the reference output for a three-step matmul scan.

  Returns a tuple (vec, mats, ref): vec comes from Tensor.randn(1, 10), mats from
  Tensor.randn(3, 10, 10) (both realized), and ref is the realized stack of the three
  intermediate products vec@mats[0], (vec@mats[0])@mats[1], ((vec@mats[0])@mats[1])@mats[2].
  """
  vec = Tensor.randn(1, 10).realize()
  mats = Tensor.randn(3, 10, 10).realize()

  # 3 matmuls in "scan": each step multiplies the previous result by the next matrix
  steps = []
  carry = vec
  for idx in range(3):
    carry = carry @ mats[idx]
    steps.append(carry)

  # keep every intermediate result, not just the last one
  ref = Tensor.stack(*steps)
  ref.realize()
  return vec, mats, ref
|
||||
|
||||
def test_uop_scan_matmul(self):
|
||||
vec, mats, ref = self._test_scan()
|
||||
|
||||
# 3 matmuls with SCAN
|
||||
i = UOp.range(3, -100, AxisType.OUTER)
|
||||
out = Tensor.empty(3, 1, 10)
|
||||
comp = Tensor(i.eq(0).where(vec.uop, out[(i-1).maximum(0)].uop)) @ mats[i]
|
||||
store = out[i].uop.store(comp.uop).end(i)
|
||||
out = Tensor(out.uop.after(store))
|
||||
out.realize()
|
||||
|
||||
# TODO: testing allclose
|
||||
|
||||
@@ -23,10 +23,12 @@ def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[
|
||||
in_degree: dict[UOp, int] = {}
|
||||
var_vals: dict[str, int] = {}
|
||||
for u in sched_sink.toposort():
|
||||
if u.op is not Ops.AFTER: continue # anything that's not an ASSIGN doesn't write a kernel, so we can skip
|
||||
if u.op is Ops.RANGE:
|
||||
in_degree.setdefault(u, 0)
|
||||
continue
|
||||
if u.op is not Ops.AFTER or u.src[1].op is Ops.RANGE: continue
|
||||
k = u.src[1]
|
||||
in_degree.setdefault(k, 0)
|
||||
if k.op is Ops.RANGE: continue
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src:
|
||||
if s.op is Ops.AFTER:
|
||||
children[s.src[1]].append(k)
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, ssimplify, KernelInfo
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element, unwrap, disable_gc
|
||||
@@ -397,6 +397,9 @@ def handle_after(ctx:LocalAddBufferContext, after:UOp):
|
||||
|
||||
def renumber_range(ctx:LocalAddBufferContext, r:UOp):
  """Rewrite the range UOp `r` when its tag is the empty tuple; otherwise return None.

  If the last element of r.arg is AxisType.OUTER, the result is a variable named
  "range_" + range_str(r), bounded by r.vmin and r.vmax, bound to `r` with its tag cleared.
  Otherwise the result is `r` with the first element of its arg replaced by ctx.range and
  its tag cleared, and ctx.range is incremented by one afterwards.
  """
  # only ranges tagged with () are handled here
  if r.tag != ():
    return None

  if r.arg[-1] == AxisType.OUTER:
    # for outer range, we replace with a bound variable
    outer_var = UOp.variable("range_"+range_str(r), r.vmin, r.vmax)
    return outer_var.bind(r.replace(tag=None))

  # take the next number from the context counter, then advance the counter
  renumbered_arg = (ctx.range,)+r.arg[1:]
  renumbered = r.replace(arg=renumbered_arg, tag=None)
  ctx.range += 1
  return renumbered
|
||||
@@ -571,12 +574,12 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
|
||||
# TODO: we can probably get this earlier
|
||||
sink_tags = [s.tag for s in tsink.src]
|
||||
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for tag, s in zip(sink_tags, tsink.src):
|
||||
assert tag is not None
|
||||
|
||||
@@ -42,7 +42,7 @@ shared_spec = PatternMatcher([
|
||||
(UPat(Ops.INDEX, src=(UPat(),), allow_any_len=True, name="x"), lambda x: all(y.dtype == dtypes.index for y in x.src[1:]) or None),
|
||||
|
||||
# RANGE/SPECIAL define loops, END closes them
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE)), dtype=dtypes.void), lambda: True),
|
||||
(UPat(Ops.END, src=(UPat(), UPat(Ops.RANGE))), lambda: True),
|
||||
])
|
||||
|
||||
# ***** UOp spec in the Tensor graph *****
|
||||
@@ -171,7 +171,7 @@ kernel_spec = PatternMatcher([
|
||||
(UPat(Ops.UNROLL, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# END can end multiple axes here
|
||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True, dtype=dtypes.void), lambda: True),
|
||||
(UPat(Ops.END, src=(UPat(), UPat()), allow_any_len=True), lambda: True),
|
||||
|
||||
# bufferize can be on anything
|
||||
(UPat(Ops.BUFFERIZE, src=(UPat(),), allow_any_len=True, name="x"), lambda x: True),
|
||||
|
||||
Reference in New Issue
Block a user