mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)

Merge branch 'master' into delete_slow_rangeify
@@ -42,7 +42,6 @@ import struct
 from tinygrad.dtype import dtypes
 from tinygrad.device import Buffer, Device
 from tinygrad.uop.ops import UOp, Ops
-from tinygrad.shape.shapetracker import ShapeTracker

 # allocate some buffers + load in values
 out = Buffer(DEVICE, 1, dtypes.int32).allocate()
@@ -51,13 +50,14 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
 # NOTE: a._buf is the same as the return from cpu.allocator.alloc

 # describe the computation
+idx = UOp.const(dtypes.index, 0)
 buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
 buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
-ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.view(ShapeTracker.from_shape((1,))),))
-ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.view(ShapeTracker.from_shape((1,))),))
+ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx),))
+ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx),))
 alu = ld_1 + ld_2
 output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
-st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.view(ShapeTracker.from_shape((1,))), alu))
+st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.index(idx), alu))
 s = UOp(Ops.SINK, dtypes.void, (st_0,))

 # convert the computation to a "linearized" format (print the format)
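
The same one-element add kernel can also be written with the UOp sugar methods that appear elsewhere in this diff (.index, .load, .store, .sink). A minimal sketch, mine rather than the commit's, assuming those helpers keep the signatures shown in the hunks below:

idx = UOp.const(dtypes.index, 0)
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
# INDEX picks the element, LOAD reads it, STORE writes the sum, SINK roots the AST
s = output_buf.index(idx).store(buf_1.index(idx).load() + buf_2.index(idx).load()).sink()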
@@ -149,6 +149,7 @@ class TestFloat4(unittest.TestCase):

     assert TestFloat4.count_float4(uops) == (1, 1)

+  @unittest.skip("Ops.VIEW no longer exists")
   def test_half4_load_unrolled(self):
     # from llama 7B shard 4 gpus
     ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
@@ -6,8 +6,6 @@ from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.codegen.gpudims import get_grouped_dims
 from tinygrad.uop.ops import UOp, Ops, GroupOp
 from tinygrad.device import Device, Buffer, is_dtype_supported
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import View
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
 from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, RANGEIFY
@@ -38,24 +36,6 @@ class TestLinearizer(unittest.TestCase):
     np.testing.assert_equal(a.numpy(), ta)
     np.testing.assert_equal(b.numpy(), tb)

-  def test_multioutput(self):
-    dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
-    g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
-    a = UOp(Ops.LOAD, dtype, src=(g2.view(st),))
-    b = UOp(Ops.LOAD, dtype, src=(g3.view(st),))
-    out0 = UOp(Ops.STORE, dtypes.void, src=(g0.view(st), a + b))
-    out1 = UOp(Ops.STORE, dtypes.void, src=(g1.view(st), a * b))
-    sink = UOp(Ops.SINK, src=(out0, out1))
-
-    a_t = Tensor.full(st.shape, 2).contiguous().realize()
-    b_t = Tensor.full(st.shape, 3).contiguous().realize()
-    helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])
-    uops = get_program(sink, opts=[]).uops
-    stores = [u for u in uops if u.op is Ops.STORE]
-    mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
-    assert len(mutable_bufs) == len(stores) == 2
-    self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1]))
-
   def _test_no_nested_ranges(self, lins, skip=None):
     for l in lins:
       range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
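
The deleted test_multioutput built a two-store SINK from view-based loads. For reference, a hypothetical index-style equivalent of the same two-output AST, built only from the .index/.load/.store helpers used elsewhere in this diff (not code from the commit):

idx = UOp.const(dtypes.index, 0)
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(4)]
a, b = g2.index(idx).load(), g3.index(idx).load()
# one SINK rooting two stores: out0 = a+b, out1 = a*b
sink = UOp(Ops.SINK, dtypes.void, (g0.index(idx).store(a + b), g1.index(idx).store(a * b)))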
@@ -437,41 +417,6 @@ class TestLinearizer(unittest.TestCase):
     # the global store doesn't change
     assert stores[1].src[1].dtype == dtypes.float

-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
-  def test_skip_unmatching_upcasts(self):
-    Tensor.manual_seed(0)
-    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=())
-    c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
-    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=())
-    c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))
-    c4 = c3.load()
-    c5 = c1.store(c4)
-    ast = c5.sink()
-    opt = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
-           Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)]
-    helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])
-    out = [u for u in get_program(ast, opts=opt).uops if u.op is Ops.STORE][0]
-    assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4)
-
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
-  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
-  def test_skip_unmatching_upcasts_with_gep(self):
-    Tensor.manual_seed(0)
-    c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=())
-    c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
-    c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=())
-    c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))
-    c4 = c3.load()
-    c5 = c1.store(c4)
-    ast = c5.sink()
-    opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
-           Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
-           Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
-    helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])
-    out = [u for u in get_program(ast).uops if u.op is Ops.STORE][0]
-    assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1
-
 # *** helpers ***

 def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:
@@ -32,6 +32,7 @@ class TestLinearizerFailure(unittest.TestCase):

 class TestLinearizerDumb(unittest.TestCase):
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
+  @unittest.skip("Ops.VALID no longer exists")
   def test_max_simplify_and_cancel(self):
     c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=())
     c1 = c0.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))
@@ -54,6 +55,7 @@ class TestLinearizerDumb(unittest.TestCase):

   # this was a bug in embedding, someday we should fold this anyway
   @unittest.skipUnless(is_dtype_supported(dtypes.half), f"half dtype not supported on {Device.DEFAULT}")
+  @unittest.skip("UOp.view is no longer supported")
   def test_llama_embedding(self):
     c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=())
     c1 = c0.view(ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)))
@@ -1954,8 +1954,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor([1,2,3,4]).realize()
     for _ in range(24): a = a + a
     new_uop = a.reshape(4,1).realize().uop
-    self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
-    self.assertEqual(swizzle_cnt(new_uop), 0)
+    assert new_uop.base.op is Ops.BUFFER

   @unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
   def test_limit_bufs_with_var(self):
@@ -1981,9 +1980,6 @@ class TestSchedule(unittest.TestCase):
     sched = z.schedule()
     self.assertEqual(len(sched), kcount+1)

-def swizzle_cnt(u:UOp) -> int:
-  return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}])
-
 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
     Tensor.manual_seed(0)
@@ -1,8 +1,6 @@
 from typing import Optional, Any
 import unittest, math
 import numpy as np
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.view import View # noqa F401
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Timing
 from tinygrad.dtype import dtypes, DType, AddrSpace
@@ -492,15 +490,6 @@ class TestUOpMethod(unittest.TestCase):
     self.assertIs(x.replace(arg=None).arg, None)
     with self.assertRaises(AssertionError): x.replace(field="a")

-  def test_device(self):
-    x = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 1)), ShapeTracker.from_shape(()))
-    self.assertEqual(x.device, Device.DEFAULT)
-    # NOTE: CONST doesn't have device
-    buffer, const = x.src
-    self.assertEqual(buffer.device, Device.DEFAULT)
-    self.assertEqual(const._device, None)
-    with self.assertRaises(AssertionError): const.device
-
 class TestUOpStr(unittest.TestCase):
   def test_uop_str(self):
     a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
@@ -6,7 +6,6 @@ from tinygrad.uop.ops import UPat, Ops, UOp
 realized_pattern = UPat(Ops.BUFFER)
 # after realization, base tensor uops become RESHAPE(BUFFER)
 buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
-const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
 def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
 def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)
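
A usage sketch for the remaining helpers (my example, not from the commit): per the comment above, a realized tensor's uop should match either the bare BUFFER pattern or the RESHAPE(BUFFER) pattern, so UPat.any (used elsewhere in this diff) covers both.

from tinygrad import Tensor

t = Tensor([1, 2, 3]).realize()
# base tensor uops become BUFFER or RESHAPE(BUFFER) after realization
is_pattern(t, UPat.any(realized_pattern, buffer_view_pattern))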
@@ -1,97 +0,0 @@
-from __future__ import annotations
-import unittest
-
-from tinygrad import Tensor
-from tinygrad.helpers import DEBUG, RANGEIFY
-from tinygrad.uop.ops import UOp, Ops, print_uops
-from tinygrad.uop.spec import type_verify, ast_spec, tensor_uop_spec
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad import dtypes
-from tinygrad.shape.view import View
-from tinygrad.engine.realize import get_program
-from tinygrad.device import Device
-
-class InvalidASTException(Exception): pass
-def helper_test_verify_ast(*stores:UOp):
-  sink = UOp(Ops.SINK, dtypes.void, stores)
-  if DEBUG >= 3:
-    for op in stores: print(op)
-  try: type_verify(list(sink.toposort()), ast_spec)
-  except RuntimeError as e: raise InvalidASTException(e.args)
-  program = get_program(sink, Device[Device.DEFAULT].renderer)
-
-  if DEBUG >= 6: print_uops(program.uops)
-  if DEBUG >= 4: print(program.src)
-
-class TestUOpSpec(unittest.TestCase):
-  def test_tiny_add(self):
-    dtype = dtypes.int
-    buf_0 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
-    buf_1 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
-    buf_2 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 2)
-    a = UOp(Ops.LOAD, dtype, (buf_1.view(ShapeTracker.from_shape((32, 1))),))
-    b = UOp(Ops.LOAD, dtype, (buf_2.view(ShapeTracker.from_shape((32, 1))),))
-    store = UOp(Ops.STORE, dtypes.void, (buf_0.view(ShapeTracker.from_shape((32, 1))), a+b))
-    helper_test_verify_ast(store)
-
-  def test_no_implicit_broadcasting(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((4, 32))),))
-    b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.MAX, (1,)))
-    st = UOp(Ops.STORE, dtypes.void, (bufs[0].view(ShapeTracker.from_shape((4, 32))), b))
-    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
-
-  def test_shrink_ok(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),))),))
-    b = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),))),))
-    st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 32))), a+b)
-    helper_test_verify_ast(st)
-
-  def test_reduce_store(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((32, 1))),))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
-    st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), r)
-    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
-
-  def test_reduce_add_store(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((32, 1))),))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
-    st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), r+a)
-    with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
-
-  def test_assert_swizzle(self):
-    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    a = UOp(Ops.LOAD, dtypes.float, (buf.view(ShapeTracker.from_shape((32, 1))),))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
-    st = UOp.store(buf.view(ShapeTracker.from_shape((32, 1))), r.view(r.st.expand((32, 1)))+a)
-    with self.assertRaisesRegex(InvalidASTException, "UOp verification failed"): helper_test_verify_ast(st)
-
-  def test_const_view_always_valid(self):
-    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(())),))
-    st = UOp.store(buf.view(ShapeTracker.from_shape(())), a.cast(dtypes.float))
-    helper_test_verify_ast(st)
-
-  @unittest.skipIf(RANGEIFY, "RANGEIFY does not push views")
-  def test_assert_masked_view_in_const(self):
-    t = Tensor(6).uop
-    a = t.replace(src=(t.src[0].replace(arg=t.st.reshape((1,)).pad(((0, 1),))),))
-    with self.assertRaisesRegex(RuntimeError, "UOp verification failed"):
-      type_verify([a], tensor_uop_spec)
-
-class TestUOpSink(unittest.TestCase):
-  def test_0(self):
-    s = UOp.sink()
-    self.assertEqual(len(s.src), 0)
-
-  def test_1(self):
-    a = UOp.const(dtypes.int, 0)
-    s1 = UOp.sink(a)
-    s2 = a.sink()
-    self.assertIs(s1, s2)
-
-if __name__ == '__main__':
-  unittest.main()
@@ -1,6 +1,6 @@
 # the job of the lowerer is to do indexing
 from dataclasses import dataclass
-from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
+from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite

 # ***** indexing *****
@@ -15,8 +15,8 @@ def shape_to_idx(s, axis_types, start=0):

 def get_index(ast:UOp) -> IndexContext:
   axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
-  if len(ast.full_shape) != len(axis_types) and ast.st is not None:
-    axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
+  #if len(ast.full_shape) != len(axis_types) and ast.st is not None:
+  #  axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
   return IndexContext(axis_types, [], 0)

 # ***** lowering (given index) *****
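
The now-disabled fallback classified each axis by comparing the sink's shape against full_shape (the elementwise max over all parents' shapes): an axis where the two differ is being reduced. A worked example with assumed shapes, illustrative only:

# shape=(32, 1) vs full_shape=(32, 32) for a sum over axis 1
shape, full_shape = (32, 1), (32, 32)
axis_types = tuple("REDUCE" if s != fs else "LOOP" for s, fs in zip(shape, full_shape))
assert axis_types == ("LOOP", "REDUCE")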
@@ -26,29 +26,6 @@ def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
   ctx.start = lc.start
   return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)

-def lower_reduce_axis(ctx: IndexContext, x: UOp):
-  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
-  full_new_idx = list(ctx.idxs)
-  for a in x.axis_arg: full_new_idx[a] = new_idxs[a]
-  ret = subblock(ctx, full_new_idx, x.src[0])
-  return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple([full_new_idx[i] for i in x.axis_arg]), x.arg[0])
-
-def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
-  # TODO: reenable after REDUCE_AXIS is fixed
-  #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
-
-  new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
-  idx = x.st_arg.to_valid_uop(new_idxs)
-  used_idxs = [x for x in idx.toposort() if x in new_idxs]
-  real_new_idxs = []
-  for i in range(len(x.src[0].shape)):
-    if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
-    else: real_new_idxs.append(ctx.idxs[i])
-
-  stored = subblock(ctx, real_new_idxs, x.src[1])
-  used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
-  return buf.index(idx).store(stored, *used_ranges)
-
 def fixup_wmma(ctx:IndexContext, x:UOp):
   if x.tag is not None: return None
   new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
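
For reference, the lowered form that the removed lower_reduce_axis produced: a REDUCE whose extra sources are the RANGE uops being reduced over, with the reduce ALU op as arg. A minimal sketch of that shape, my reconstruction, assuming UOp.range keeps the range(end, *arg) signature shown in the ops.py hunk below:

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
r = UOp.range(32, 0)                                                    # a 32-wide axis
acc = UOp(Ops.REDUCE, dtypes.float, (buf.index(r).load(), r), Ops.ADD)  # sum over r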
@@ -63,21 +40,6 @@ def fixup_wmma(ctx:IndexContext, x:UOp):
   return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)

 pm_lowerer = PatternMatcher([
-  # TODO: remove these hacks
-  # hack for old style CONST(VIEW) (now it's just VIEW(CONST))
-  (UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
-  # hack for old style VALID (now it's just VIEW(CONST))
-  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
-
-  # consts and loads
-  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
-   lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))),
-  (UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
-   lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])),
-
-  # reduce/view_const
-  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
-  (UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
   (UPat(Ops.WMMA, name="x"), fixup_wmma),

  # axis fixups for WMMA
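
A PatternMatcher pairs a UPat with a rewrite lambda and is driven by graph_rewrite, as in subblock above. A tiny self-contained sketch of the mechanism (my example, assuming graph_rewrite can be called without a context and that x+0 is not folded eagerly at construction):

from tinygrad import dtypes
from tinygrad.uop.ops import PatternMatcher, UPat, UOp, graph_rewrite

# rewrite x + 0 -> x; the lambda returns None when the rule doesn't apply
pm = PatternMatcher([(UPat.var("x") + UPat.cvar("c"), lambda x, c: x if c.arg == 0 else None)])
expr = UOp.const(dtypes.int, 5) + UOp.const(dtypes.int, 0)
print(graph_rewrite(expr, pm))  # expect the bare CONST 5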
@@ -27,13 +27,13 @@ pm_quant = symbolic+PatternMatcher([
  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
  # mul 0 * c1 is 0
- (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
-  UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
+ #(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+ # UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
  # mul (with plus) 0 * c1 is 0
- (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
-  (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
-   UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
-  lambda ld,v,c1: ld*c1),
+ #(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+ # (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
+ #  UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
+ # lambda ld,v,c1: ld*c1),

  # const push through add
  ((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
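
A quick numeric check of the first rule above (which this commit keeps): when c1 and c2 agree to within 1e-9, x*c1 + y*c2 refactors to (x+y)*c1.

x, y, c1, c2 = 3.0, 4.0, 0.25, 0.25
assert abs((x*c1 + y*c2) - (x + y)*c1) < 1e-12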
@@ -64,4 +64,4 @@ pm_quant = symbolic+PatternMatcher([
   lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
  (UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
   lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
 ])
@@ -29,12 +29,6 @@ class Ops(FastEnum):
   RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
   MULTI = auto() # MULTI is really a movement op

-  # view is what all movement ops become
-  VIEW = auto()
-
-  # TODO: remove VALID with the VIEW(CONST(DEVICE)) refactor
-  VALID = auto()
-
   # TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
   DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
@@ -96,7 +90,7 @@ class GroupOp:
   Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
   Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}

-  Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
+  Buffer = {Ops.LOAD, Ops.STORE, Ops.CONST, Ops.DEFINE_VAR}
   Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}

 # BinaryOps that can be flipped
@@ -7,7 +7,7 @@ from tinygrad.uop import Ops, GroupOp
 from tinygrad.uop.mathtraits import MathTrait
 from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
-from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ, SPEC
+from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
 from tinygrad.helpers import strip_parens
 if TYPE_CHECKING:
   from tinygrad.shape.shapetracker import ShapeTracker
@@ -186,8 +186,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.BARRIER: return None
     if self.op in GroupOp.Block: return None
     from tinygrad.shape.shapetracker import ShapeTracker
-    # VIEW and MovementOps define a new ShapeTracker from the arg
-    if self.op is Ops.VIEW: return self.arg
+    # MovementOps define a new ShapeTracker from the arg
     if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape(tuple([int(r.vmax+1) for r in self.src[1:]]))
     # allow reshape from nothing
     if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
@@ -198,7 +197,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if self.op is Ops.STORE and self.dtype is not dtypes.void: return self.src[0].src[0].st
     # BufferOps and ASSIGN flow ShapeTracker from a direct edge
     if self.op in {Ops.STORE, Ops.ASSIGN, Ops.LOAD}: return self.src[0].st
-    if self.op in GroupOp.Buffer: return views[0] if (views:=[x.st for x in self.src if x.op is Ops.VIEW]) else None

     # BUFFER/BUFFER_VIEW and KERNEL only have a size
     if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
@@ -229,12 +227,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       case _: shape = src_sts[0].shape
     return ShapeTracker.from_shape(shape)

-  @functools.cached_property
-  def full_shape(self) -> tuple[sint, ...]:
-    if self.op is Ops.VIEW: return self.shape
-    # NOTE: if a parent doesn't have st its full_shape is empty
-    parent_shapes = [x.full_shape for x in self.src]
-    return tuple(smax(x) for x in itertools.zip_longest(*parent_shapes, fillvalue=1))
   @property
   def shape(self) -> tuple[sint, ...]:
     assert self.st is not None, f"{self.op} doesn't have a shape"
@@ -345,17 +337,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
     ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
-    if RANGEIFY:
-      # VIEW on const is no longer supported in RANGEIFY
-      if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
-      if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
-    else:
-      if shape is not None:
-        from tinygrad.shape.shapetracker import ShapeTracker
-        ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
-      if device is not None:
-        if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
-        else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
+    if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
+    if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
     return ret
   @staticmethod
   def range(end:sint, *arg):
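
Under the surviving path, a device-and-shape const is a bare CONST with a DEVICE source, broadcast with ordinary movement ops rather than a zero-stride VIEW. A minimal sketch using only helpers that appear in this diff (the "CPU" device string is an assumption):

c = UOp.const(dtypes.float, 1.0)                  # bare CONST, no srcs
c = c.replace(src=(UOp(Ops.DEVICE, arg="CPU"),))  # attach the device, as above
c = c.reshape((1, 1)).expand((4, 4))              # movement ops instead of a VIEW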
@@ -455,10 +438,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

   @property
   def base(self) -> UOp:
-    if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
+    if self.op in GroupOp.Movement: return self.src[0].base
     if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
     return self
-  def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self,), new_st)

   def _mop(self, op:Ops, arg) -> UOp:
     ret = UOp(op, self.dtype, (self,), arg)
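
With UOp.view gone, .base now only unwraps movement ops (and MULTI). A small sketch; new_buffer is used the same way in the deleted test_device above, and "CPU" is an assumed device:

buf = UOp.new_buffer("CPU", 4, dtypes.float)
assert buf.reshape((4, 1)).base is buf  # RESHAPE is a movement op, so base unwraps it
assert buf.base is buf                  # a non-movement uop is its own base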
@@ -571,8 +553,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
   def variables(self) -> list[Variable]:
-    st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
-    return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
+    return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

 # *** uop symbolic stuff ***
@@ -789,7 +770,6 @@ class UPat(MathTrait):
   # copied from UOp
   def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
   def index(self, idx:UPat, valid:UPat|None=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
-  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
   def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
   def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
   def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
@@ -1145,7 +1125,6 @@ renderer = PatternMatcher([
  (UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
  (UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
  (UPat(set(syms.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
- (UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")),
  (UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x:
   UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
  (UPat(Ops.VECTORIZE, src=UPat(Ops.NOOP), name="x"),
@@ -1172,8 +1151,6 @@ pm_pyrender = PatternMatcher([
    arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
   lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")),
- (UPat(Ops.VALID, src=(UPat(Ops.NOOP),), name="x"),
-  lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, dtype=dtypes.bool)")),
 ])

 @Context(SPEC=0)
@@ -1182,8 +1159,8 @@ def pyrender(ast:UOp) -> list[str]:
   to_render = set()
   for u in ast.toposort():
     if u.op is Ops.STORE: to_render.add(u.src[1])
-    if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.LOAD} or u.op in {Ops.CONST}: continue
-    if u.op in {Ops.SINK, Ops.VIEW}:
+    if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD} or u.op in {Ops.CONST}: continue
+    if u.op in {Ops.SINK}:
       for s in u.src: to_render.add(s)
     to_render.add(u)
   ret: list[str] = []
@@ -1,8 +1,7 @@
 from typing import cast, Callable
 from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType
 from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
-from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile, RANGEIFY
-from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile
 try:
   import z3
   # older versions of z3 dont have some operators like & overloaded
@@ -64,8 +63,6 @@ buffer_spec = PatternMatcher([
  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
   lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
  (UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
- # allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
- (UPat(Ops.VIEW), lambda: True),
 ])

 assign_spec = PatternMatcher([
@@ -92,17 +89,10 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
    # this is fine as long as it's a realized buffer or const and base dtypes match.
    ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base \
     and x.base.op in {Ops.BUFFER,Ops.ASSIGN,Ops.CONST})),
- (UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),

  # Tensor variable bindings
  (UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),

- # Tensor const has a device and an unmasked ShapeTracker of stride 0
- # NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
- # TODO: remove after rangeify is default
- (UPat(Ops.CONST, src=(UPat.any(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="st"),
-                                UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND)), name="st")),)),
-  lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
  (UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),

  # DETACH and CONTIGUOUS change how we interpret the source UOp
@@ -167,20 +157,8 @@ spec = PatternMatcher([
    all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
  (UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),

- (UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)),
- (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"),
-  lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
-
- (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
  (UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),

- # early LOAD has a <bufview, store?>
- (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
- (UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
-
- # early STORE has a <bufview, val>
- (UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),
-
  # **** new style load/store ****

  # make sure all index dtypes have been lowered
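
With the early bufview LOAD/STORE forms gone, the shapes the spec still accepts are the index-based ones. A minimal sketch of that "new style" (my example, built only from helpers that appear elsewhere in this diff):

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
idx = UOp.const(dtypes.index, 0)
ld = buf.index(idx).load()     # LOAD of an INDEX into a defined buffer
st = buf.index(idx).store(ld)  # STORE of a value at an INDEX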
@@ -243,20 +221,12 @@ spec = PatternMatcher([
 # *** this is the UOp AST spec ***

 ast_spec = PatternMatcher([
-  # VIEW can only exist in the edges
-  (UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
-  (UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
   # all parent UOps must have the same shape
   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
 ])

 # *** this spec should match all UOps ever created ***

-full_non_rangeify_spec = PatternMatcher([]) if RANGEIFY else PatternMatcher([
-  # in non rangeify const can still have a View, and sometimes a FUSE while propagating
-  (UPat((Ops.VIEW, Ops.FUSE)).f(Ops.CONST), lambda: True),
-])
-
 full_spec = PatternMatcher([
   # SENTINEL should never be in the graph
   (UPat(Ops.SENTINEL), lambda: False),
@@ -310,7 +280,7 @@ full_spec = PatternMatcher([
   (UPat(Ops.DEFINE_VAR), lambda: True),
   # reshape on STORE
   (UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
-])+full_non_rangeify_spec+tensor_uop_spec+spec
+])+tensor_uop_spec+spec

 # ***** uop helpers *****
@@ -16,7 +16,7 @@ from tinygrad.codegen.opt import axis_colors
 uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
                Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
                Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
-               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
+               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
                **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
                Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
                Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
@@ -62,14 +62,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
   for u in (toposort:=x.toposort()):
     # always exclude DEVICE/CONST/UNIQUE
     if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
-    # only exclude CONST VIEW source if it has no other children in the graph
-    if u.op is Ops.CONST and u.st is not None: excluded.update(u.src)
   for u in toposort:
     if u in excluded: continue
     argst = codecs.decode(str(u.arg), "unicode_escape")
-    if u.op is Ops.VIEW:
-      argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
-                          (f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
     if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.arg)
     label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
     if u.dtype != dtypes.void: label += f"\n{u.dtype}"
@@ -80,7 +75,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
     try:
       if len(rngs:=u.ranges):
         label += f"\n({','.join([colored(range_str(x), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
-      if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
+      if u.op not in {Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
         label += f"\n{shape_to_str(u.shape)}"
       if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
         label += f"\n{u.render()}"