mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
works
This commit is contained in:
@@ -3,11 +3,16 @@ import unittest
|
||||
from tinygrad import Tensor
|
||||
|
||||
class TestScan(unittest.TestCase):
|
||||
def test_reduce_add(self):
|
||||
a = Tensor.randn(10, 10).realize()
|
||||
a_red = a.sum(axis=1)
|
||||
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1), atol=1e-6)
|
||||
|
||||
def test_scan_add(self):
|
||||
a = Tensor.randn(10, 10).realize()
|
||||
init = Tensor.zeros(10, 1)
|
||||
a_red = (a+init).scan(init)
|
||||
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1))
|
||||
init = Tensor.zeros(10, 1).contiguous()
|
||||
a_red = (a+init).scan(init).reshape(10)
|
||||
np.testing.assert_allclose(a_red.numpy(), a.numpy().sum(axis=1), atol=1e-6)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -308,9 +308,19 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
|
||||
if len(reduce_range) == 0: return ret
|
||||
return acc.after(acc.index(UOp.const(dtypes.int, 0)).store(ret).end(*reduce_range)).index(UOp.const(dtypes.int, 0))
|
||||
|
||||
def scan_to_store(x:UOp):
|
||||
_, acc, ranges = x.src[0], x.src[1], x.src[2:]
|
||||
assert acc.op is Ops.INDEX
|
||||
buf = acc.src[0]
|
||||
ret = x.substitute({buf: buf.rtag().after(*ranges)}).substitute({buf.rtag(): buf})
|
||||
base, acc, ranges = ret.src[0], ret.src[1], ret.src[2:]
|
||||
return buf.after(acc.store(base).end(*ranges)).index(acc.src[1])
|
||||
|
||||
pm_reduce = PatternMatcher([
|
||||
# REDUCE -> DEFINE_ACC+ASSIGN
|
||||
# REDUCE -> DEFINE_ACC+STORE
|
||||
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
|
||||
# SCAN -> STORE
|
||||
(UPat(Ops.SCAN, name="x"), scan_to_store),
|
||||
# tensor core built in accumulate
|
||||
(UPat(Ops.WMMA, name="wmma") + UPat.var("add"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
|
||||
@@ -54,7 +54,7 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
if x.op in {Ops.BUFFERIZE, Ops.INDEX}: return None
|
||||
if x.op is Ops.AFTER and x.src[1].op is Ops.KERNEL: return None
|
||||
new_srcs = []
|
||||
for s in x.src:
|
||||
for i,s in enumerate(x.src):
|
||||
new_src = s
|
||||
if s.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.MSTACK, Ops.MSELECT} or (s.op is Ops.AFTER and s.src[1].op is Ops.KERNEL):
|
||||
if x in ctx.range_map: new_src = new_src.index(*ctx.range_map[x][0])
|
||||
@@ -65,7 +65,9 @@ def create_bufferize_and_index_based_on_ranges(ctx:IndexingContext, x:UOp):
|
||||
# None in the device assigns it a number later
|
||||
opts = BufferizeOpts(device=s.device) if len(ctx.range_map[s][1]) == len(realized_ranges) else BufferizeOpts(None, AddrSpace.LOCAL)
|
||||
new_src = UOp(Ops.BUFFERIZE, s.dtype, src=(new_src,)+closed_ranges, arg=opts, tag=s.tag if opts.addrspace == AddrSpace.GLOBAL else None)
|
||||
if x in ctx.range_map: new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][0]) if i in realized_ranges])
|
||||
if x in ctx.range_map:
|
||||
# for scan we use the output ranges on the 2nd arg
|
||||
new_src = new_src.index(*[r for i,r in enumerate(ctx.range_map[x][int(x.op is Ops.SCAN and i == 1)]) if i in realized_ranges])
|
||||
new_srcs.append(new_src)
|
||||
# NOTE: do we need this?
|
||||
return x.replace(src=tns) if x.src != (tns:=tuple(new_srcs)) else None
|
||||
@@ -84,6 +86,13 @@ def convert_reduce_axis_to_reduce_with_ranges(ctx:IndexingContext, x:UOp):
|
||||
ctx.range_map[ret] = ctx.range_map[x]
|
||||
return ret
|
||||
|
||||
def add_ranges_to_scan(ctx:IndexingContext, x:UOp):
|
||||
if x not in ctx.range_map: return None
|
||||
new_ranges = [r for r,ar in zip(*ctx.range_map[x]) if r is not ar and r not in x.src]
|
||||
ret = x.replace(src=x.src+tuple(new_ranges))
|
||||
ctx.range_map[ret] = ctx.range_map[x]
|
||||
return ret
|
||||
|
||||
def remove_movement_op_after_rangeify(ctx:IndexingContext, x:UOp):
|
||||
if x in ctx.range_map or x.src[0].op is Ops.INDEX: return x.src[0]
|
||||
|
||||
@@ -97,6 +106,8 @@ def add_third_op_to_assign_to_track_shape(ctx:IndexingContext, assign:UOp):
|
||||
pm_apply_rangeify = PatternMatcher([
|
||||
# REDUCE_AXIS -> REDUCE
|
||||
(UPat(Ops.REDUCE_AXIS, name="x"), convert_reduce_axis_to_reduce_with_ranges),
|
||||
# SCAN -> SCAN (with new ranges)
|
||||
(UPat(Ops.SCAN, name="x"), add_ranges_to_scan),
|
||||
# PAD -> WHERE
|
||||
(UPat(Ops.PAD, name="x"), convert_pad_to_where_to_keep_behavior_local),
|
||||
# add third op to assign
|
||||
@@ -244,6 +255,9 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
# REDUCE_AXIS creates ranges for the axes it is reducing
|
||||
if x.op is Ops.REDUCE_AXIS:
|
||||
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if i in x.arg[1] else r for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
|
||||
if x.op is Ops.SCAN:
|
||||
rngs = tuple(rctx.new_range(s, axistype=AxisType.REDUCE) if resolve(x.src[1].shape[i] == 1) else r \
|
||||
for i,(r,s) in enumerate(zip(rngs, x.src[0].shape)))
|
||||
|
||||
if debug:
|
||||
realized_ranges = rctx.realize_map.get(x, None)
|
||||
|
||||
@@ -20,7 +20,7 @@ axis_letters = {AxisType.GLOBAL: "g", AxisType.THREAD: "t", AxisType.LOCAL: "l",
|
||||
axis_colors = {AxisType.GLOBAL: "blue", AxisType.THREAD: "BLUE", AxisType.LOCAL: "cyan", AxisType.WARP: "CYAN", AxisType.LOOP: "WHITE",
|
||||
AxisType.UPCAST: "yellow", AxisType.GROUP_REDUCE: "RED", AxisType.REDUCE: "red", AxisType.UNROLL: "magenta"}
|
||||
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
|
||||
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1, Ops.SCAN: 2}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
@@ -17,7 +17,7 @@ from tinygrad.dtype import dtypes
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.DEFINE_GLOBAL:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.SCAN: "#FF7B7B",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0",
|
||||
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
|
||||
|
||||
Reference in New Issue
Block a user