Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
cleanup uop eq asserts for swizzle [run_process_replay] (#6362)
* cleanup uop eq asserts for swizzle [run_process_replay]
* more stuff
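For context: this commit replaces the recursive field-by-field TestUOps.assert_equiv_uops test mixin with a standalone assert_equiv_uops helper that compares the two UOps' precomputed structural key and prints a diff on mismatch. Below is a minimal self-contained sketch of that idea; MiniUOp, assert_equiv, and the sha256-based key are illustrative stand-ins, not tinygrad's actual implementation.

# Illustrative sketch only (not tinygrad code): structural equivalence via a
# recursively computed key, so __eq__ can stay identity-based like UOp's.
import hashlib
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True, eq=False)  # eq=False: equality stays object identity
class MiniUOp:
  op: str
  dtype: str = ""
  src: Tuple["MiniUOp", ...] = ()
  arg: Any = None

  @property
  def key(self) -> bytes:
    # fold op/dtype/arg and the keys of all sources into one digest
    h = hashlib.sha256(repr((self.op, self.dtype, self.arg)).encode())
    for s in self.src: h.update(s.key)
    return h.digest()

def assert_equiv(u1: MiniUOp, u2: MiniUOp) -> None:
  if u1.key != u2.key: raise AssertionError("uops aren't equal.")

a = MiniUOp("CONST", "float", (), 1.0)
b = MiniUOp("CONST", "float", (), 1.0)
assert a is not b   # different objects...
assert_equiv(a, b)  # ...but structurally equivalent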
@@ -1,6 +1,7 @@
import sys, unittest, time
from typing import Callable, Optional, Set, Tuple, TypeVar
import sys, time
from typing import Callable, Tuple, TypeVar
import numpy as np
from test.external.process_replay.helpers import print_diff
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker
@@ -56,22 +57,10 @@ def rand_for_dtype(dt:DType, size:int):
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))

class TestUOps(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None):
if cache is None: cache = set()
if (uop1, uop2) in cache: return
cache.add((uop1, uop2))
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
try:
self.assertIs(uop1.op, uop2.op)
self.assertEqual(uop1.dtype, uop2.dtype)
self.assertEqual(uop1.arg, uop2.arg)
self.assertEqual(len(uop1.src), len(uop2.src))
for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2, cache)
except AssertionError as e:
print(f"{uop1=}")
print(f"{uop2=}")
raise e
def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
if u1.key != u2.key:
print_diff(u1, u2)
raise AssertionError("uops aren't equal.")

def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]) -> UOp:
return UOp(UOps.CONST, dtype, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),),
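The hunk above adds the new helper next to the test utilities that the files below import from test.helpers. A hedged usage sketch, not part of the commit: it assumes a tinygrad checkout at roughly this revision, where create_schedule takes a list of lazy buffers and each schedule item exposes a UOp ast, as the hunks below do; the tensors are illustrative.

# Usage sketch (assumptions noted above): identical computations should yield
# structurally equivalent ASTs, while a different constant should not.
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from test.helpers import assert_equiv_uops

a = Tensor([1, 2, 3, 4]) + 1
b = Tensor([1, 2, 3, 4]) + 1
c = Tensor([1, 2, 3, 4]) + 2
sched_a = create_schedule([a.lazydata])
sched_b = create_schedule([b.lazydata])
sched_c = create_schedule([c.lazydata])

assert_equiv_uops(sched_a[-1].ast, sched_b[-1].ast)    # same structure: passes
try:
  assert_equiv_uops(sched_a[-1].ast, sched_c[-1].ast)  # different const arg
except AssertionError:
  pass                                                 # prints a diff, then raises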
@@ -1,12 +1,12 @@
import unittest
import time
import numpy as np
from test.helpers import TestUOps
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item, run_schedule

class TestFusionOp(TestUOps):
class TestFusionOp(unittest.TestCase):
def test_contiguous_add(self):
def test(contig=False):
bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
@@ -43,8 +43,8 @@ class TestFusionOp(TestUOps):
c = Tensor([1,2,3,4])
for _ in range(23): c = c + c
sched3 = create_schedule([c.lazydata], None)
self.assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): self.assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
self.assertLess(time.perf_counter()-st, 2.0)

if __name__ == '__main__':
@@ -1,9 +1,8 @@
import unittest, itertools
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401

class TestPatternMatcher(TestUOps):
class TestPatternMatcher(unittest.TestCase):
def test_simple_match(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@@ -1,10 +1,10 @@
import unittest, pickle
import numpy as np
from test.helpers import TestUOps
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable
from tinygrad.engine.schedule import create_schedule

class TestPickle(TestUOps):
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
@@ -64,7 +64,7 @@ class TestPickle(TestUOps):
sched = create_schedule([out.lazydata])
pk = pickle.dumps(sched)
sched_pk = pickle.loads(pk)
self.assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)

class TestPickleJIT(unittest.TestCase):
@classmethod
@@ -14,9 +14,9 @@ from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup
from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import is_dtype_supported, Context, timeit
from test.helpers import assert_equiv_uops, is_dtype_supported, Context, timeit
from tinygrad.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis

@@ -1299,7 +1299,7 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape

class TestConvBW(unittest.TestCase):
def check_schedule(self, xt, cnt:int, flops=None):
def check_schedule(self, xt, cnt:int, flops=None) -> List[ScheduleItem]:
with Context(FUSE_CONV_BW=getenv("FUSE_CONV_BW", 1), NOOPT=flops is not None):
s = create_schedule(flatten([r.lazydata.lbs for r in xt]))
kernels = [si for si in s if si.ast.op is UOps.SINK]
@@ -1308,6 +1308,7 @@ class TestConvBW(unittest.TestCase):
run_schedule(s)
if flops is not None: assert GlobalCounters.global_ops <= flops, f"too many ops {GlobalCounters.global_ops}"
if FUSE_CONV_BW: self.assertEqual(len(kernels), cnt)
return kernels

def test_fold_conv_relu_backward(self):
c1 = nn.Conv2d(3,16,3, bias=False)
@@ -1342,7 +1343,7 @@ class TestConvBW(unittest.TestCase):
img = Tensor(img_np, requires_grad=True)
c1(img).relu().mean().backward()
assert img.grad is not None and c1.weight.grad is not None
with Context(AST_REWRITE=1): self.check_schedule([img.grad, c1.weight.grad], 3)
with Context(AST_REWRITE=1): compare_ast = self.check_schedule([img.grad, c1.weight.grad], 3)[1].ast
rw_flops = GlobalCounters.global_ops
# ref
GlobalCounters.reset()
@@ -1351,7 +1352,7 @@ class TestConvBW(unittest.TestCase):
img_ref = Tensor(img_np, requires_grad=True)
c1_ref(img_ref).relu().mean().backward()
assert img_ref.grad is not None and c1_ref.weight.grad is not None
with Context(AST_REWRITE=0): self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)
with Context(AST_REWRITE=0): ref_ast = self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)[1].ast
ref_flops = GlobalCounters.global_ops
# correctness
np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
@@ -1359,6 +1360,7 @@ class TestConvBW(unittest.TestCase):
# flops, TODO: This will be fixed once SWIZZLE merges view strides.
with self.assertRaises(AssertionError):
self.assertEqual(rw_flops, ref_flops)
assert_equiv_uops(compare_ast, ref_ast)

@unittest.expectedFailure
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@@ -1,6 +1,6 @@
from typing import List
import unittest, time
from test.helpers import TestUOps
from test.helpers import assert_equiv_uops
from tinygrad import dtypes, Variable, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
@@ -142,7 +142,7 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)

class TestUOpGraph(TestUOps):
class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@@ -214,7 +214,7 @@ class TestUOpGraph(TestUOps):
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4))
self.assert_equiv_uops(_test_vec(xyzw), val)
assert_equiv_uops(_test_vec(xyzw), val)

# unaligned
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
@@ -242,7 +242,7 @@ class TestUOpGraph(TestUOps):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)])
for uop, const in zip(uops, consts):
self.assert_equiv_uops(uop, const)
assert_equiv_uops(uop, const)

def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
@@ -251,7 +251,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc)
assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)

for i in [2, 4, 8]:
@@ -260,7 +260,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc)
assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)

def test_wmma_vectorize_no_fold(self):
@@ -272,7 +272,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)

for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
@@ -282,7 +282,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)

for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
@@ -291,7 +291,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)

for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
@@ -300,7 +300,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)

def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
@@ -346,9 +346,9 @@ class TestUOpGraph(TestUOps):
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))

def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -361,9 +361,9 @@ class TestUOpGraph(TestUOps):
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))

def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -375,7 +375,7 @@ class TestUOpGraph(TestUOps):
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 4)
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))

def test_asserts_bad_gate(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -611,7 +611,7 @@ class TestLoadStoreFolder(unittest.TestCase):

def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)

class TestIFUOps(TestUOps):
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
@@ -627,7 +627,7 @@ class TestIFUOps(TestUOps):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assert_equiv_uops(if_uops[0].src[0], gate)
assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)

@@ -645,7 +645,7 @@ class TestIFUOps(TestUOps):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assert_equiv_uops(if_uops[0].src[0], gate)
assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)

@@ -661,7 +661,7 @@ class TestIFUOps(TestUOps):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assert_equiv_uops(if_uops[0].src[0], gate)
assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)

@@ -11,7 +11,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.shape.symbolic import Variable
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
from test.helpers import is_dtype_supported, assert_equiv_uops

def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)

@@ -356,7 +356,7 @@ class TestUOpCompare(unittest.TestCase):
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
assert (add < mul) or (mul < add), "add and mul with same src should have an order"

class TestUOpStr(TestEqUOps):
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
for _ in range(20): a = a + a
@@ -368,7 +368,7 @@ class TestUOpStr(TestEqUOps):
# nice big complicated uop
with Context(NOOPT=1):
sink = UOp(UOps.SINK, None, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
self.assert_equiv_uops(sink, eval(str(sink)))
assert_equiv_uops(sink, eval(str(sink)))

def test_nop_str(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")