Merge branch 'master' into delete_slow_rangeify

This commit is contained in:
George Hotz
2025-10-08 19:58:16 +08:00
committed by GitHub
14 changed files with 31 additions and 298 deletions

View File

@@ -42,7 +42,6 @@ import struct
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device
from tinygrad.uop.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
# allocate some buffers + load in values
out = Buffer(DEVICE, 1, dtypes.int32).allocate()
@@ -51,13 +50,14 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
# NOTE: a._buf is the same as the return from cpu.allocator.alloc
# describe the computation
idx = UOp.const(dtypes.index, 0)
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.view(ShapeTracker.from_shape((1,))),))
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.view(ShapeTracker.from_shape((1,))),))
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1.index(idx),))
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2.index(idx),))
alu = ld_1 + ld_2
output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.view(ShapeTracker.from_shape((1,))), alu))
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf.index(idx), alu))
s = UOp(Ops.SINK, dtypes.void, (st_0,))
# convert the computation to a "linearized" format (print the format)

View File

@@ -149,6 +149,7 @@ class TestFloat4(unittest.TestCase):
assert TestFloat4.count_float4(uops) == (1, 1)
@unittest.skip("Ops.VIEW no longer exists")
def test_half4_load_unrolled(self):
# from llama 7B shard 4 gpus
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(

View File

@@ -6,8 +6,6 @@ from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, RANGEIFY
@@ -38,24 +36,6 @@ class TestLinearizer(unittest.TestCase):
np.testing.assert_equal(a.numpy(), ta)
np.testing.assert_equal(b.numpy(), tb)
def test_multioutput(self):
dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
a = UOp(Ops.LOAD, dtype, src=(g2.view(st),))
b = UOp(Ops.LOAD, dtype, src=(g3.view(st),))
out0 = UOp(Ops.STORE, dtypes.void, src=(g0.view(st), a + b))
out1 = UOp(Ops.STORE, dtypes.void, src=(g1.view(st), a * b))
sink = UOp(Ops.SINK, src=(out0, out1))
a_t = Tensor.full(st.shape, 2).contiguous().realize()
b_t = Tensor.full(st.shape, 3).contiguous().realize()
helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])
uops = get_program(sink, opts=[]).uops
stores = [u for u in uops if u.op is Ops.STORE]
mutable_bufs = dedup(flatten([[x for x in u.src[0].toposort() if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
assert len(mutable_bufs) == len(stores) == 2
self.assertSetEqual(set([u.arg for u in mutable_bufs]), set([0,1]))
def _test_no_nested_ranges(self, lins, skip=None):
for l in lins:
range_in_acc = flatten([[x for x in u.src if x.op is Ops.RANGE] for u in l.uops if u.op is Ops.DEFINE_REG])
@@ -437,41 +417,6 @@ class TestLinearizer(unittest.TestCase):
# the global store doesn't change
assert stores[1].src[1].dtype == dtypes.float
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts(self):
Tensor.manual_seed(0)
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(9600), arg=1, src=())
c3 = c2.view(ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)))
c4 = c3.load()
c5 = c1.store(c4)
ast = c5.sink()
opt = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16),
Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)]
helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])
out = [u for u in get_program(ast, opts=opt).uops if u.op is Ops.STORE][0]
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype == dtypes.float.vec(4)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
def test_skip_unmatching_upcasts_with_gep(self):
Tensor.manual_seed(0)
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))
c2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(256), arg=1, src=())
c3 = c2.view(ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))
c4 = c3.load()
c5 = c1.store(c4)
ast = c5.sink()
opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8),
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8),
Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)]
helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])
out = [u for u in get_program(ast).uops if u.op is Ops.STORE][0]
assert out.src[1].op is Ops.VECTORIZE and out.src[1].dtype.count != 1
# *** helpers ***
def helper_realized_ast(r:Tensor|list[Tensor]) -> tuple[UOp, list[Buffer]]:

View File

@@ -32,6 +32,7 @@ class TestLinearizerFailure(unittest.TestCase):
class TestLinearizerDumb(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
@unittest.skip("Ops.VALID no longer exists")
def test_max_simplify_and_cancel(self):
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1000), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))
@@ -54,6 +55,7 @@ class TestLinearizerDumb(unittest.TestCase):
# this was a bug in embedding, someday we should fold this anyway
@unittest.skipUnless(is_dtype_supported(dtypes.half), f"half dtype not supported on {Device.DEFAULT}")
@unittest.skip("UOp.view is no longer supported")
def test_llama_embedding(self):
c0 = UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(4096), arg=0, src=())
c1 = c0.view(ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)))

View File

@@ -1954,8 +1954,7 @@ class TestSchedule(unittest.TestCase):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
new_uop = a.reshape(4,1).realize().uop
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
assert new_uop.base.op is Ops.BUFFER
@unittest.skipIf(CI and Device.DEFAULT == "NV", "crashes on NV CI")
def test_limit_bufs_with_var(self):
@@ -1981,9 +1980,6 @@ class TestSchedule(unittest.TestCase):
sched = z.schedule()
self.assertEqual(len(sched), kcount+1)
def swizzle_cnt(u:UOp) -> int:
return len([x for x in u.toposort() if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op not in {Ops.BUFFER, Ops.DEFINE_GLOBAL, Ops.ASSIGN}])
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
Tensor.manual_seed(0)

View File

@@ -1,8 +1,6 @@
from typing import Optional, Any
import unittest, math
import numpy as np
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View # noqa F401
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Timing
from tinygrad.dtype import dtypes, DType, AddrSpace
@@ -492,15 +490,6 @@ class TestUOpMethod(unittest.TestCase):
self.assertIs(x.replace(arg=None).arg, None)
with self.assertRaises(AssertionError): x.replace(field="a")
def test_device(self):
x = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 1)), ShapeTracker.from_shape(()))
self.assertEqual(x.device, Device.DEFAULT)
# NOTE: CONST doesn't have device
buffer, const = x.src
self.assertEqual(buffer.device, Device.DEFAULT)
self.assertEqual(const._device, None)
with self.assertRaises(AssertionError): const.device
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)

View File

@@ -6,7 +6,6 @@ from tinygrad.uop.ops import UPat, Ops, UOp
realized_pattern = UPat(Ops.BUFFER)
# after realization, base tensor uops become RESHAPE(BUFFER)
buffer_view_pattern = UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),))
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.uop, pat)

View File

@@ -1,97 +0,0 @@
from __future__ import annotations
import unittest
from tinygrad import Tensor
from tinygrad.helpers import DEBUG, RANGEIFY
from tinygrad.uop.ops import UOp, Ops, print_uops
from tinygrad.uop.spec import type_verify, ast_spec, tensor_uop_spec
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad import dtypes
from tinygrad.shape.view import View
from tinygrad.engine.realize import get_program
from tinygrad.device import Device
class InvalidASTException(Exception): pass
def helper_test_verify_ast(*stores:UOp):
sink = UOp(Ops.SINK, dtypes.void, stores)
if DEBUG >= 3:
for op in stores: print(op)
try: type_verify(list(sink.toposort()), ast_spec)
except RuntimeError as e: raise InvalidASTException(e.args)
program = get_program(sink, Device[Device.DEFAULT].renderer)
if DEBUG >= 6: print_uops(program.uops)
if DEBUG >= 4: print(program.src)
class TestUOpSpec(unittest.TestCase):
def test_tiny_add(self):
dtype = dtypes.int
buf_0 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 2)
a = UOp(Ops.LOAD, dtype, (buf_1.view(ShapeTracker.from_shape((32, 1))),))
b = UOp(Ops.LOAD, dtype, (buf_2.view(ShapeTracker.from_shape((32, 1))),))
store = UOp(Ops.STORE, dtypes.void, (buf_0.view(ShapeTracker.from_shape((32, 1))), a+b))
helper_test_verify_ast(store)
def test_no_implicit_broadcasting(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((4, 32))),))
b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.MAX, (1,)))
st = UOp(Ops.STORE, dtypes.void, (bufs[0].view(ShapeTracker.from_shape((4, 32))), b))
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_shrink_ok(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),))),))
b = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),))),))
st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 32))), a+b)
helper_test_verify_ast(st)
def test_reduce_store(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((32, 1))),))
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), r)
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_reduce_add_store(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1].view(ShapeTracker.from_shape((32, 1))),))
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
st = UOp.store(bufs[0].view(ShapeTracker.from_shape((32, 1))), r+a)
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_assert_swizzle(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp(Ops.LOAD, dtypes.float, (buf.view(ShapeTracker.from_shape((32, 1))),))
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (Ops.ADD, (0,)))
st = UOp.store(buf.view(ShapeTracker.from_shape((32, 1))), r.view(r.st.expand((32, 1)))+a)
with self.assertRaisesRegex(InvalidASTException, "UOp verification failed"): helper_test_verify_ast(st)
def test_const_view_always_valid(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(())),))
st = UOp.store(buf.view(ShapeTracker.from_shape(())), a.cast(dtypes.float))
helper_test_verify_ast(st)
@unittest.skipIf(RANGEIFY, "RANGEIFY does not push views")
def test_assert_masked_view_in_const(self):
t = Tensor(6).uop
a = t.replace(src=(t.src[0].replace(arg=t.st.reshape((1,)).pad(((0, 1),))),))
with self.assertRaisesRegex(RuntimeError, "UOp verification failed"):
type_verify([a], tensor_uop_spec)
class TestUOpSink(unittest.TestCase):
def test_0(self):
s = UOp.sink()
self.assertEqual(len(s.src), 0)
def test_1(self):
a = UOp.const(dtypes.int, 0)
s1 = UOp.sink(a)
s2 = a.sink()
self.assertIs(s1, s2)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,6 +1,6 @@
# the job of the lowerer is to do indexing
from dataclasses import dataclass
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite, resolve
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint_to_uop, AxisType, graph_rewrite
# ***** indexing *****
@@ -15,8 +15,8 @@ def shape_to_idx(s, axis_types, start=0):
def get_index(ast:UOp) -> IndexContext:
axis_types = ast.arg.axis_types if isinstance(ast.arg, KernelInfo) else ()
if len(ast.full_shape) != len(axis_types) and ast.st is not None:
axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
#if len(ast.full_shape) != len(axis_types) and ast.st is not None:
# axis_types = tuple([AxisType.REDUCE if resolve(s != fs) else AxisType.LOOP for s,fs in zip(ast.shape, ast.full_shape)])
return IndexContext(axis_types, [], 0)
# ***** lowering (given index) *****
@@ -26,29 +26,6 @@ def subblock(ctx: IndexContext, full_new_idx: list[UOp], src: UOp):
ctx.start = lc.start
return graph_rewrite(src, pm_lowerer, lc, name="subblock", bottom_up=True)
def lower_reduce_axis(ctx: IndexContext, x: UOp):
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
full_new_idx = list(ctx.idxs)
for a in x.axis_arg: full_new_idx[a] = new_idxs[a]
ret = subblock(ctx, full_new_idx, x.src[0])
return UOp(Ops.REDUCE, x.dtype, (ret,)+tuple([full_new_idx[i] for i in x.axis_arg]), x.arg[0])
def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
# TODO: reenable after REDUCE_AXIS is fixed
#assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
idx = x.st_arg.to_valid_uop(new_idxs)
used_idxs = [x for x in idx.toposort() if x in new_idxs]
real_new_idxs = []
for i in range(len(x.src[0].shape)):
if new_idxs[i] in used_idxs or len(ctx.idxs) <= i: real_new_idxs.append(new_idxs[i])
else: real_new_idxs.append(ctx.idxs[i])
stored = subblock(ctx, real_new_idxs, x.src[1])
used_ranges = [x for x in used_idxs if x.op is Ops.RANGE]
return buf.index(idx).store(stored, *used_ranges)
def fixup_wmma(ctx:IndexContext, x:UOp):
if x.tag is not None: return None
new_idxs = shape_to_idx(x.src[0].shape, ctx.axis_types, ctx.start)
@@ -63,21 +40,6 @@ def fixup_wmma(ctx:IndexContext, x:UOp):
return x.replace(src=srcs, arg=x.arg[:-2]+(new_x_arg_m2, new_x_arg_m1), tag=1)
pm_lowerer = PatternMatcher([
# TODO: remove these hacks
# hack for old style CONST(VIEW) (now it's just VIEW(CONST))
(UPat((Ops.DEFINE_VAR, Ops.CONST), src=(UPat(Ops.VIEW, name="v"),), name="c"), lambda c,v: c.replace(src=()).view(v.arg)),
# hack for old style VALID (now it's just VIEW(CONST))
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c"), UPat(Ops.CONST, arg=0)), lambda c,v: c.replace(src=()).view(v.arg)),
# consts and loads
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),), name="view"),
lambda ctx,view,c: c if all(x.mask is None for x in view.arg.views) else view.arg.to_valid_uop(ctx.idxs).get_valid().where(c, c.const_like(0))),
(UPat(Ops.LOAD, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"),
lambda ctx,buf,x: UOp(Ops.LOAD, x.dtype, (buf.index(x.st_arg.to_valid_uop(ctx.idxs)),)+x.src[1:])),
# reduce/view_const
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.STORE, src=(UPat.var("buf").view(),), allow_any_len=True, name="x"), lower_store),
(UPat(Ops.WMMA, name="x"), fixup_wmma),
# axis fixups for WMMA

View File

@@ -27,13 +27,13 @@ pm_quant = symbolic+PatternMatcher([
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
# mul 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
#(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
# UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
# mul (with plus) 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
(UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
#(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
# (UPat(Ops.LOAD, src=(UPat().view(name="v"),)).cast(dtypes.int) + \
# UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
# lambda ld,v,c1: ld*c1),
# const push through add
((UPat.var("x")*UPat.cvar("c1") + UPat.var("y")*UPat.cvar("c2")) * UPat.cvar("c3"), lambda x,y,c1,c2,c3: (x*c1*c3) + (y*c2*c3)),
@@ -64,4 +64,4 @@ pm_quant = symbolic+PatternMatcher([
lambda v1,v2,c1,r: r.replace(src=(v1*v2,)) + r.replace(src=(c1*v2,))),
(UPat(Ops.REDUCE_AXIS, src=((UPat(Ops.CAST, name="v1")+UPat.var("c1")) * (UPat(Ops.CAST, name="v2",)+UPat.var("c2")),), name="r"),
lambda v1,v2,c1,c2,r: r.replace(src=(v1*v2,)) + r.replace(src=(c2*v1,)) + r.replace(src=(c1*v2,)) + r.replace(src=(c1*c2,))),
])
])

View File

@@ -29,12 +29,6 @@ class Ops(FastEnum):
RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); FLIP = auto() # noqa: E702
MULTI = auto() # MULTI is really a movement op
# view is what all movement ops become
VIEW = auto()
# TODO: remove VALID with the VIEW(CONST(DEVICE)) refactor
VALID = auto()
# TODO: unify these ops into the levels of the memory hierarchy. depends on ASSIGN is STORE
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_REG = auto() # noqa: E702
@@ -96,7 +90,7 @@ class GroupOp:
Irreducible = {Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE}
Movement = {Ops.RESHAPE, Ops.EXPAND, Ops.PERMUTE, Ops.PAD, Ops.SHRINK, Ops.FLIP}
Buffer = {Ops.LOAD, Ops.STORE, Ops.VALID, Ops.CONST, Ops.DEFINE_VAR}
Buffer = {Ops.LOAD, Ops.STORE, Ops.CONST, Ops.DEFINE_VAR}
Block = {Ops.BLOCK, Ops.BLOCKEND, Ops.BLOCKSTART}
# BinaryOps that can be flipped

View File

@@ -7,7 +7,7 @@ from tinygrad.uop import Ops, GroupOp
from tinygrad.uop.mathtraits import MathTrait
from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, RANGEIFY, VIZ, SPEC
from tinygrad.helpers import PICKLE_BUFFERS, PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
from tinygrad.helpers import strip_parens
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -186,8 +186,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.BARRIER: return None
if self.op in GroupOp.Block: return None
from tinygrad.shape.shapetracker import ShapeTracker
# VIEW and MovementOps define a new ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
# MovementOps define a new ShapeTracker from the arg
if self.op is Ops.BUFFERIZE: return ShapeTracker.from_shape(tuple([int(r.vmax+1) for r in self.src[1:]]))
# allow reshape from nothing
if self.op is Ops.RESHAPE and self.src[0].st is None: return ShapeTracker.from_shape(self.arg)
@@ -198,7 +197,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.STORE and self.dtype is not dtypes.void: return self.src[0].src[0].st
# BufferOps and ASSIGN flow ShapeTracker from a direct edge
if self.op in {Ops.STORE, Ops.ASSIGN, Ops.LOAD}: return self.src[0].st
if self.op in GroupOp.Buffer: return views[0] if (views:=[x.st for x in self.src if x.op is Ops.VIEW]) else None
# BUFFER/BUFFER_VIEW and KERNEL only have a size
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
@@ -229,12 +227,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
case _: shape = src_sts[0].shape
return ShapeTracker.from_shape(shape)
@functools.cached_property
def full_shape(self) -> tuple[sint, ...]:
if self.op is Ops.VIEW: return self.shape
# NOTE: if a parent doesn't have st its full_shape is empty
parent_shapes = [x.full_shape for x in self.src]
return tuple(smax(x) for x in itertools.zip_longest(*parent_shapes, fillvalue=1))
@property
def shape(self) -> tuple[sint, ...]:
assert self.st is not None, f"{self.op} doesn't have a shape"
@@ -345,17 +337,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
ret = UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype), src=() if src is None else (src,))
if RANGEIFY:
# VIEW on const is no longer supported in RANGEIFY
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
else:
if shape is not None:
from tinygrad.shape.shapetracker import ShapeTracker
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
if device is not None:
if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
if device is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
if shape is not None: ret = ret.reshape((1,)*len(shape)).expand(shape)
return ret
@staticmethod
def range(end:sint, *arg):
@@ -455,10 +438,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def base(self) -> UOp:
if (self.op is Ops.VIEW and len(self.src) != 0) or self.op in GroupOp.Movement: return self.src[0].base
if self.op in GroupOp.Movement: return self.src[0].base
if self.op is Ops.MULTI: return self.src[0].base # MULTI is really a VIEW
return self
def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
def _mop(self, op:Ops, arg) -> UOp:
ret = UOp(op, self.dtype, (self,), arg)
@@ -571,8 +553,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> list[Variable]:
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# *** uop symbolic stuff ***
@@ -789,7 +770,6 @@ class UPat(MathTrait):
# copied from UOp
def sink(self, *srcs:UPat|None, **kwargs): return UPat(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def index(self, idx:UPat, valid:UPat|None=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None, **kwargs): return UPat(Ops.CAST, dtype, (self,), **kwargs)
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int|None=None, **kwargs): return UPat(Ops.GEP, None, (self,), (i,) if i is not None else None, **kwargs)
@@ -1145,7 +1125,6 @@ renderer = PatternMatcher([
(UPat(Ops.MULACC, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.WHERE, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(set(syms.keys()), src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.op]}{x.src[1].arg})")),
(UPat(Ops.VIEW, src=(UPat(Ops.NOOP),), name="x"), lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.view({x.arg})")),
(UPat((Ops.INDEX, Ops.BUFFERIZE), name="x"), lambda x:
UOp(Ops.NOOP, arg=''.join([f"[{strip_parens(y.arg)}]" for y in x.src[1:]])) if all(y.op is Ops.NOOP for y in x.src[1:]) else None),
(UPat(Ops.VECTORIZE, src=UPat(Ops.NOOP), name="x"),
@@ -1172,8 +1151,6 @@ pm_pyrender = PatternMatcher([
arg=f"{x.src[0].arg}.{sugar[x.op]}({', '.join([y.arg for y in x.src[1:]] + ([f'arg={str(x.arg)}'] if x.arg is not None else []))})")),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.NOOP),), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, arg=({', '.join([str(y) for y in x.arg])}))")),
(UPat(Ops.VALID, src=(UPat(Ops.NOOP),), name="x"),
lambda x: UOp(Ops.NOOP, arg=f"{x.src[0].arg}.f({x.op}, dtype=dtypes.bool)")),
])
@Context(SPEC=0)
@@ -1182,8 +1159,8 @@ def pyrender(ast:UOp) -> list[str]:
to_render = set()
for u in ast.toposort():
if u.op is Ops.STORE: to_render.add(u.src[1])
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.VIEW, Ops.LOAD} or u.op in {Ops.CONST}: continue
if u.op in {Ops.SINK, Ops.VIEW}:
if len(cmap[u]) == 1 and u.op not in {Ops.DEFINE_GLOBAL, Ops.LOAD} or u.op in {Ops.CONST}: continue
if u.op in {Ops.SINK}:
for s in u.src: to_render.add(s)
to_render.add(u)
ret: list[str] = []

View File

@@ -1,8 +1,7 @@
from typing import cast, Callable
from tinygrad.uop.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops, python_alu, graph_rewrite, AxisType
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType, AddrSpace, Invalid
from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile, RANGEIFY
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.helpers import all_same, prod, DEBUG, IGNORE_OOB, Context, cpu_profile
try:
import z3
# older versions of z3 dont have some operators like & overloaded
@@ -64,8 +63,6 @@ buffer_spec = PatternMatcher([
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.MSTACK, src=UPat(Ops.BUFFER)),)), lambda: True),
# allow VIEW here. TODO: what views specifically are allowed? does this mess with gradient?
(UPat(Ops.VIEW), lambda: True),
])
assign_spec = PatternMatcher([
@@ -92,17 +89,10 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# this is fine as long as it's a realized buffer or const and base dtypes match.
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base \
and x.base.op in {Ops.BUFFER,Ops.ASSIGN,Ops.CONST})),
(UPat(Ops.VIEW, src=(UPat.var("x"),)), lambda x: x.base.op in {Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN, Ops.CONST, Ops.DEVICE}),
# Tensor variable bindings
(UPat(Ops.BIND, (dtypes.int,dtypes.index,), (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=(dtypes.int,dtypes.index,))), arg=None), lambda: True),
# Tensor const has a device and an unmasked ShapeTracker of stride 0
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
# TODO: remove after rangeify is default
(UPat(Ops.CONST, src=(UPat.any(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="st"),
UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND)), name="st")),)),
lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
@@ -167,20 +157,8 @@ spec = PatternMatcher([
all(isinstance(ra, int) for ra in rng.arg[0:-1]) and isinstance(rng.arg[-1], AxisType)),
(UPat(Ops.SPECIAL, src=(UPat.var("x"),), name="s"), lambda s,x: s.dtype == x.dtype == dtypes.int32 and isinstance(s.arg, str)),
(UPat(Ops.VIEW, dtypes.void, src=(), name="x"), lambda x: isinstance(x.arg, ShapeTracker)),
(UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"),
lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# early LOAD has a <bufview, store?>
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat(Ops.STORE))), lambda: True),
# early STORE has a <bufview, val>
(UPat(Ops.STORE, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)), UPat())), lambda: True),
# **** new style load/store ****
# make sure all index dtypes have been lowered
@@ -243,20 +221,12 @@ spec = PatternMatcher([
# *** this is the UOp AST spec ***
ast_spec = PatternMatcher([
# VIEW can only exist in the edges
(UPat(Ops.VIEW, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL),))), lambda: True),
(UPat(Ops.VIEW, name="view"), lambda view: len(view.src) == 0),
# all parent UOps must have the same shape
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: all_same([x.shape for x in root.src if x.st is not None])),
])
# *** this spec should match all UOps ever created ***
full_non_rangeify_spec = PatternMatcher([]) if RANGEIFY else PatternMatcher([
# in non rangeify const can still have a View, and sometimes a FUSE while propagating
(UPat((Ops.VIEW, Ops.FUSE)).f(Ops.CONST), lambda: True),
])
full_spec = PatternMatcher([
# SENTINEL should never be in the graph
(UPat(Ops.SENTINEL), lambda: False),
@@ -310,7 +280,7 @@ full_spec = PatternMatcher([
(UPat(Ops.DEFINE_VAR), lambda: True),
# reshape on STORE
(UPat(Ops.RESHAPE, src=(UPat(Ops.STORE),)), lambda: True),
])+full_non_rangeify_spec+tensor_uop_spec+spec
])+tensor_uop_spec+spec
# ***** uop helpers *****

View File

@@ -16,7 +16,7 @@ from tinygrad.codegen.opt import axis_colors
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_REG: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BUFFER_VIEW: "#E5EAFF",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.FUSE: "#FFa500",
Ops.ALLREDUCE: "#ff40a0", Ops.MSELECT: "#d040a0", Ops.MSTACK: "#d040a0", Ops.CONTIGUOUS: "#FFC14D",
@@ -62,14 +62,9 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
for u in (toposort:=x.toposort()):
# always exclude DEVICE/CONST/UNIQUE
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
# only exclude CONST VIEW source if it has no other children in the graph
if u.op is Ops.CONST and u.st is not None: excluded.update(u.src)
for u in toposort:
if u in excluded: continue
argst = codecs.decode(str(u.arg), "unicode_escape")
if u.op is Ops.VIEW:
argst = ("\n".join([f"{shape_to_str(v.shape)} / {shape_to_str(v.strides)}"+("" if v.offset == 0 else f" / {srender(v.offset)}")+
(f"\nMASK {mask_to_str(v.mask)}" if v.mask is not None else "") for v in unwrap(u.st).views]))
if u.op in GroupOp.Movement: argst = (mask_to_str if u.op in {Ops.SHRINK, Ops.PAD} else shape_to_str)(u.arg)
label = f"{str(u.op).split('.')[1]}{(chr(10)+word_wrap(argst.replace(':', ''))) if u.arg is not None else ''}"
if u.dtype != dtypes.void: label += f"\n{u.dtype}"
@@ -80,7 +75,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
try:
if len(rngs:=u.ranges):
label += f"\n({','.join([colored(range_str(x), axis_colors[x.arg[-1]]) for x in sorted(rngs, key=lambda x: x.arg[0:-1])])})"
if u.op not in {Ops.VIEW, Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
if u.op not in {Ops.BUFFER, Ops.KERNEL, Ops.ASSIGN, Ops.COPY, Ops.SINK, *GroupOp.Buffer} and u.st is not None:
label += f"\n{shape_to_str(u.shape)}"
if u.op in {Ops.INDEX, Ops.BUFFERIZE}:
label += f"\n{u.render()}"