From e7f6b654ad011f6c318328180b8f335906211a4e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 5 Sep 2024 13:36:36 +0800 Subject: [PATCH] cleanup uop eq asserts for swizzle [run_process_replay] (#6362) * cleanup uop eq asserts for swizzle [run_process_replay] * more stuff --- test/helpers.py | 25 +++++++----------------- test/test_fusion_op.py | 8 ++++---- test/test_pattern_matcher.py | 3 +-- test/test_pickle.py | 6 +++--- test/test_schedule.py | 12 +++++++----- test/test_uop_graph.py | 38 ++++++++++++++++++------------------ test/test_uops.py | 6 +++--- 7 files changed, 44 insertions(+), 54 deletions(-) diff --git a/test/helpers.py b/test/helpers.py index 772d3a090c..c0d22939de 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,6 +1,7 @@ -import sys, unittest, time -from typing import Callable, Optional, Set, Tuple, TypeVar +import sys, time +from typing import Callable, Tuple, TypeVar import numpy as np +from test.external.process_replay.helpers import print_diff from tinygrad import Tensor, Device, dtypes from tinygrad.ops import UOp, UOps from tinygrad.shape.shapetracker import ShapeTracker @@ -56,22 +57,10 @@ def rand_for_dtype(dt:DType, size:int): return np.random.choice([True, False], size=size) return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) -class TestUOps(unittest.TestCase): - def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None): - if cache is None: cache = set() - if (uop1, uop2) in cache: return - cache.add((uop1, uop2)) - # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops - try: - self.assertIs(uop1.op, uop2.op) - self.assertEqual(uop1.dtype, uop2.dtype) - self.assertEqual(uop1.arg, uop2.arg) - self.assertEqual(len(uop1.src), len(uop2.src)) - for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2, cache) - except AssertionError as e: - print(f"{uop1=}") - print(f"{uop2=}") - raise e +def assert_equiv_uops(u1:UOp, u2:UOp) -> None: + if u1.key != u2.key: + print_diff(u1, u2) + raise AssertionError("uops aren't equal.") def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]) -> UOp: return UOp(UOps.CONST, dtype, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),), diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py index 0fdd8944ac..0337bf55a4 100644 --- a/test/test_fusion_op.py +++ b/test/test_fusion_op.py @@ -1,12 +1,12 @@ import unittest import time import numpy as np -from test.helpers import TestUOps +from test.helpers import assert_equiv_uops from tinygrad import Tensor, dtypes from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item, run_schedule -class TestFusionOp(TestUOps): +class TestFusionOp(unittest.TestCase): def test_contiguous_add(self): def test(contig=False): bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4) @@ -43,8 +43,8 @@ class TestFusionOp(TestUOps): c = Tensor([1,2,3,4]) for _ in range(23): c = c + c sched3 = create_schedule([c.lazydata], None) - self.assert_equiv_uops(sched1[-1].ast, sched2[-1].ast) - with self.assertRaises(AssertionError): self.assert_equiv_uops(sched1[-1].ast, sched3[-1].ast) + assert_equiv_uops(sched1[-1].ast, sched2[-1].ast) + with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast) self.assertLess(time.perf_counter()-st, 2.0) if __name__ == '__main__': diff --git a/test/test_pattern_matcher.py b/test/test_pattern_matcher.py index d88b468df3..41458ada14 100644 --- a/test/test_pattern_matcher.py +++ b/test/test_pattern_matcher.py @@ -1,9 +1,8 @@ import unittest, itertools -from test.helpers import TestUOps from tinygrad.dtype import dtypes from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 -class TestPatternMatcher(TestUOps): +class TestPatternMatcher(unittest.TestCase): def test_simple_match(self): matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) diff --git a/test/test_pickle.py b/test/test_pickle.py index 265ba40a22..050cb4cef3 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -1,10 +1,10 @@ import unittest, pickle import numpy as np -from test.helpers import TestUOps +from test.helpers import assert_equiv_uops from tinygrad import Tensor, TinyJit, Variable from tinygrad.engine.schedule import create_schedule -class TestPickle(TestUOps): +class TestPickle(unittest.TestCase): def test_pickle_realized_tensor(self): t = Tensor.rand(10, 10).realize() st = pickle.dumps(t) @@ -64,7 +64,7 @@ class TestPickle(TestUOps): sched = create_schedule([out.lazydata]) pk = pickle.dumps(sched) sched_pk = pickle.loads(pk) - self.assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast) + assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast) class TestPickleJIT(unittest.TestCase): @classmethod diff --git a/test/test_schedule.py b/test/test_schedule.py index c0656fc3b0..4ec168c3a7 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -14,9 +14,9 @@ from tinygrad.tensor import Tensor from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP from tinygrad.codegen.kernel import Kernel, verify_ast -from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup +from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem from tinygrad.engine.realize import CompiledRunner, run_schedule -from test.helpers import is_dtype_supported, Context, timeit +from test.helpers import assert_equiv_uops, is_dtype_supported, Context, timeit from tinygrad.lazy import LazyBuffer, view_supported_devices from extra.models.llama import precompute_freqs_cis @@ -1299,7 +1299,7 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape class TestConvBW(unittest.TestCase): - def check_schedule(self, xt, cnt:int, flops=None): + def check_schedule(self, xt, cnt:int, flops=None) -> List[ScheduleItem]: with Context(FUSE_CONV_BW=getenv("FUSE_CONV_BW", 1), NOOPT=flops is not None): s = create_schedule(flatten([r.lazydata.lbs for r in xt])) kernels = [si for si in s if si.ast.op is UOps.SINK] @@ -1308,6 +1308,7 @@ class TestConvBW(unittest.TestCase): run_schedule(s) if flops is not None: assert GlobalCounters.global_ops <= flops, f"too many ops {GlobalCounters.global_ops}" if FUSE_CONV_BW: self.assertEqual(len(kernels), cnt) + return kernels def test_fold_conv_relu_backward(self): c1 = nn.Conv2d(3,16,3, bias=False) @@ -1342,7 +1343,7 @@ class TestConvBW(unittest.TestCase): img = Tensor(img_np, requires_grad=True) c1(img).relu().mean().backward() assert img.grad is not None and c1.weight.grad is not None - with Context(AST_REWRITE=1): self.check_schedule([img.grad, c1.weight.grad], 3) + with Context(AST_REWRITE=1): compare_ast = self.check_schedule([img.grad, c1.weight.grad], 3)[1].ast rw_flops = GlobalCounters.global_ops # ref GlobalCounters.reset() @@ -1351,7 +1352,7 @@ class TestConvBW(unittest.TestCase): img_ref = Tensor(img_np, requires_grad=True) c1_ref(img_ref).relu().mean().backward() assert img_ref.grad is not None and c1_ref.weight.grad is not None - with Context(AST_REWRITE=0): self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3) + with Context(AST_REWRITE=0): ref_ast = self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)[1].ast ref_flops = GlobalCounters.global_ops # correctness np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5) @@ -1359,6 +1360,7 @@ class TestConvBW(unittest.TestCase): # flops, TODO: This will be fixed once SWIZZLE merges view strides. with self.assertRaises(AssertionError): self.assertEqual(rw_flops, ref_flops) + assert_equiv_uops(compare_ast, ref_ast) @unittest.expectedFailure @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 825a2e87c5..2daa69da68 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -1,6 +1,6 @@ from typing import List import unittest, time -from test.helpers import TestUOps +from test.helpers import assert_equiv_uops from tinygrad import dtypes, Variable, Device from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG @@ -142,7 +142,7 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(sink.src[1].op, UOps.CONST) self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3) -class TestUOpGraph(TestUOps): +class TestUOpGraph(unittest.TestCase): def test_add_constant_fold(self): c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) @@ -214,7 +214,7 @@ class TestUOpGraph(TestUOps): # possible val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx)) xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4)) - self.assert_equiv_uops(_test_vec(xyzw), val) + assert_equiv_uops(_test_vec(xyzw), val) # unaligned val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx)) @@ -242,7 +242,7 @@ class TestUOpGraph(TestUOps): vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts)) uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]) for uop, const in zip(uops, consts): - self.assert_equiv_uops(uop, const) + assert_equiv_uops(uop, const) def test_wmma_vectorize_fold(self): for i in [2, 4, 8]: @@ -251,7 +251,7 @@ class TestUOpGraph(TestUOps): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) - self.assert_equiv_uops(uops[0], acc) + assert_equiv_uops(uops[0], acc) self.assertEqual(len(uops), 1) for i in [2, 4, 8]: @@ -260,7 +260,7 @@ class TestUOpGraph(TestUOps): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) - self.assert_equiv_uops(uops[0], acc) + assert_equiv_uops(uops[0], acc) self.assertEqual(len(uops), 1) def test_wmma_vectorize_no_fold(self): @@ -272,7 +272,7 @@ class TestUOpGraph(TestUOps): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) - self.assert_equiv_uops(uops[-1], wmma) + assert_equiv_uops(uops[-1], wmma) for i in [4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) @@ -282,7 +282,7 @@ class TestUOpGraph(TestUOps): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) - self.assert_equiv_uops(uops[-1], wmma) + assert_equiv_uops(uops[-1], wmma) for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), @@ -291,7 +291,7 @@ class TestUOpGraph(TestUOps): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) - self.assert_equiv_uops(uops[-1], wmma) + assert_equiv_uops(uops[-1], wmma) for i in [2, 4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0)) @@ -300,7 +300,7 @@ class TestUOpGraph(TestUOps): acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) - self.assert_equiv_uops(uops[-1], wmma) + assert_equiv_uops(uops[-1], wmma) def test_cast_alu_fold(self): d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0) @@ -346,9 +346,9 @@ class TestUOpGraph(TestUOps): uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))]) ld0, ld1 = uops[-1].src[2].src # ld0 becomes the invalid value - self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) + assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) # the gate and invalid value are deleted from ld1 - self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int)) + assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int)) def test_fold_gated_load_local(self): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -361,9 +361,9 @@ class TestUOpGraph(TestUOps): uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))]) ld0, ld1 = uops[-1].src[2].src # ld0 becomes the invalid value - self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) + assert_equiv_uops(ld1, UOp.const(dtypes.int, 2)) # the gate and invalid value are deleted from ld1 - self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) + assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) def test_fold_gated_store(self): glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -375,7 +375,7 @@ class TestUOpGraph(TestUOps): uops = to_uops_list([st0, st1]) # only the second store happens self.assertEqual(len(uops), 4) - self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) + assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val)) def test_asserts_bad_gate(self): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) @@ -611,7 +611,7 @@ class TestLoadStoreFolder(unittest.TestCase): def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer) -class TestIFUOps(TestUOps): +class TestIFUOps(unittest.TestCase): def test_create_ifs(self): gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4)) @@ -627,7 +627,7 @@ class TestIFUOps(TestUOps): sink = gate_rewrite(sink) if_uops = [u for u in sink.parents if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) - self.assert_equiv_uops(if_uops[0].src[0], gate) + assert_equiv_uops(if_uops[0].src[0], gate) for st in sink.src: self.assertEqual(len(st.src), 3) @@ -645,7 +645,7 @@ class TestIFUOps(TestUOps): sink = gate_rewrite(sink) if_uops = [u for u in sink.parents if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) - self.assert_equiv_uops(if_uops[0].src[0], gate) + assert_equiv_uops(if_uops[0].src[0], gate) for st in sink.src: self.assertEqual(len(st.src), 3) @@ -661,7 +661,7 @@ class TestIFUOps(TestUOps): sink = gate_rewrite(sink) if_uops = [u for u in sink.parents if u.op is UOps.IF] self.assertEqual(len(if_uops), 1) - self.assert_equiv_uops(if_uops[0].src[0], gate) + assert_equiv_uops(if_uops[0].src[0], gate) for st in sink.src: self.assertEqual(len(st.src), 3) diff --git a/test/test_uops.py b/test/test_uops.py index 0302da3bd4..a419a4aafa 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -11,7 +11,7 @@ from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite from tinygrad.shape.symbolic import Variable -from test.helpers import is_dtype_supported, TestUOps as TestEqUOps +from test.helpers import is_dtype_supported, assert_equiv_uops def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check) @@ -356,7 +356,7 @@ class TestUOpCompare(unittest.TestCase): mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL) assert (add < mul) or (mul < add), "add and mul with same src should have an order" -class TestUOpStr(TestEqUOps): +class TestUOpStr(unittest.TestCase): def test_uop_str(self): a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0) for _ in range(20): a = a + a @@ -368,7 +368,7 @@ class TestUOpStr(TestEqUOps): # nice big complicated uop with Context(NOOPT=1): sink = UOp(UOps.SINK, None, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],)) - self.assert_equiv_uops(sink, eval(str(sink))) + assert_equiv_uops(sink, eval(str(sink))) def test_nop_str(self): a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")