cleanup uop eq asserts for swizzle [run_process_replay] (#6362)

* cleanup uop eq asserts for swizzle [run_process_replay]

* more stuff
This commit is contained in:
qazal
2024-09-05 13:36:36 +08:00
committed by GitHub
parent 72be31cb56
commit e7f6b654ad
7 changed files with 44 additions and 54 deletions

View File

@@ -1,6 +1,7 @@
import sys, unittest, time
from typing import Callable, Optional, Set, Tuple, TypeVar
import sys, time
from typing import Callable, Tuple, TypeVar
import numpy as np
from test.external.process_replay.helpers import print_diff
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker
@@ -56,22 +57,10 @@ def rand_for_dtype(dt:DType, size:int):
return np.random.choice([True, False], size=size)
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
class TestUOps(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None):
if cache is None: cache = set()
if (uop1, uop2) in cache: return
cache.add((uop1, uop2))
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
try:
self.assertIs(uop1.op, uop2.op)
self.assertEqual(uop1.dtype, uop2.dtype)
self.assertEqual(uop1.arg, uop2.arg)
self.assertEqual(len(uop1.src), len(uop2.src))
for s1, s2 in zip(uop1.src, uop2.src): self.assert_equiv_uops(s1, s2, cache)
except AssertionError as e:
print(f"{uop1=}")
print(f"{uop2=}")
raise e
def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
if u1.key != u2.key:
print_diff(u1, u2)
raise AssertionError("uops aren't equal.")
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]) -> UOp:
return UOp(UOps.CONST, dtype, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),),

View File

@@ -1,12 +1,12 @@
import unittest
import time
import numpy as np
from test.helpers import TestUOps
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item, run_schedule
class TestFusionOp(TestUOps):
class TestFusionOp(unittest.TestCase):
def test_contiguous_add(self):
def test(contig=False):
bt = Tensor(np.arange(16), dtype=dtypes.float32).reshape(4,4)
@@ -43,8 +43,8 @@ class TestFusionOp(TestUOps):
c = Tensor([1,2,3,4])
for _ in range(23): c = c + c
sched3 = create_schedule([c.lazydata], None)
self.assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): self.assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
assert_equiv_uops(sched1[-1].ast, sched2[-1].ast)
with self.assertRaises(AssertionError): assert_equiv_uops(sched1[-1].ast, sched3[-1].ast)
self.assertLess(time.perf_counter()-st, 2.0)
if __name__ == '__main__':

View File

@@ -1,9 +1,8 @@
import unittest, itertools
from test.helpers import TestUOps
from tinygrad.dtype import dtypes
from tinygrad.ops import UOps, UOp, PatternMatcher, UPat, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
class TestPatternMatcher(TestUOps):
class TestPatternMatcher(unittest.TestCase):
def test_simple_match(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)

View File

@@ -1,10 +1,10 @@
import unittest, pickle
import numpy as np
from test.helpers import TestUOps
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable
from tinygrad.engine.schedule import create_schedule
class TestPickle(TestUOps):
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
@@ -64,7 +64,7 @@ class TestPickle(TestUOps):
sched = create_schedule([out.lazydata])
pk = pickle.dumps(sched)
sched_pk = pickle.loads(pk)
self.assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
class TestPickleJIT(unittest.TestCase):
@classmethod

View File

@@ -14,9 +14,9 @@ from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import AST_REWRITE, CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup
from tinygrad.engine.schedule import create_schedule, get_output_st, reduceop_fusor, st_fixup, ScheduleItem
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import is_dtype_supported, Context, timeit
from test.helpers import assert_equiv_uops, is_dtype_supported, Context, timeit
from tinygrad.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis
@@ -1299,7 +1299,7 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
class TestConvBW(unittest.TestCase):
def check_schedule(self, xt, cnt:int, flops=None):
def check_schedule(self, xt, cnt:int, flops=None) -> List[ScheduleItem]:
with Context(FUSE_CONV_BW=getenv("FUSE_CONV_BW", 1), NOOPT=flops is not None):
s = create_schedule(flatten([r.lazydata.lbs for r in xt]))
kernels = [si for si in s if si.ast.op is UOps.SINK]
@@ -1308,6 +1308,7 @@ class TestConvBW(unittest.TestCase):
run_schedule(s)
if flops is not None: assert GlobalCounters.global_ops <= flops, f"too many ops {GlobalCounters.global_ops}"
if FUSE_CONV_BW: self.assertEqual(len(kernels), cnt)
return kernels
def test_fold_conv_relu_backward(self):
c1 = nn.Conv2d(3,16,3, bias=False)
@@ -1342,7 +1343,7 @@ class TestConvBW(unittest.TestCase):
img = Tensor(img_np, requires_grad=True)
c1(img).relu().mean().backward()
assert img.grad is not None and c1.weight.grad is not None
with Context(AST_REWRITE=1): self.check_schedule([img.grad, c1.weight.grad], 3)
with Context(AST_REWRITE=1): compare_ast = self.check_schedule([img.grad, c1.weight.grad], 3)[1].ast
rw_flops = GlobalCounters.global_ops
# ref
GlobalCounters.reset()
@@ -1351,7 +1352,7 @@ class TestConvBW(unittest.TestCase):
img_ref = Tensor(img_np, requires_grad=True)
c1_ref(img_ref).relu().mean().backward()
assert img_ref.grad is not None and c1_ref.weight.grad is not None
with Context(AST_REWRITE=0): self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)
with Context(AST_REWRITE=0): ref_ast = self.check_schedule([img_ref.grad, c1_ref.weight.grad], 3)[1].ast
ref_flops = GlobalCounters.global_ops
# correctness
np.testing.assert_allclose(c1.weight.grad.numpy(), c1_ref.weight.grad.numpy(), atol=5e-4, rtol=1e-5)
@@ -1359,6 +1360,7 @@ class TestConvBW(unittest.TestCase):
# flops, TODO: This will be fixed once SWIZZLE merges view strides.
with self.assertRaises(AssertionError):
self.assertEqual(rw_flops, ref_flops)
assert_equiv_uops(compare_ast, ref_ast)
@unittest.expectedFailure
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")

View File

@@ -1,6 +1,6 @@
from typing import List
import unittest, time
from test.helpers import TestUOps
from test.helpers import assert_equiv_uops
from tinygrad import dtypes, Variable, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
@@ -142,7 +142,7 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
class TestUOpGraph(TestUOps):
class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@@ -214,7 +214,7 @@ class TestUOpGraph(TestUOps):
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), i) for i in range(4))
self.assert_equiv_uops(_test_vec(xyzw), val)
assert_equiv_uops(_test_vec(xyzw), val)
# unaligned
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
@@ -242,7 +242,7 @@ class TestUOpGraph(TestUOps):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)])
for uop, const in zip(uops, consts):
self.assert_equiv_uops(uop, const)
assert_equiv_uops(uop, const)
def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
@@ -251,7 +251,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc)
assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)
for i in [2, 4, 8]:
@@ -260,7 +260,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[0], acc)
assert_equiv_uops(uops[0], acc)
self.assertEqual(len(uops), 1)
def test_wmma_vectorize_no_fold(self):
@@ -272,7 +272,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)
for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
@@ -282,7 +282,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
@@ -291,7 +291,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
@@ -300,7 +300,7 @@ class TestUOpGraph(TestUOps):
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assert_equiv_uops(uops[-1], wmma)
assert_equiv_uops(uops[-1], wmma)
def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
@@ -346,9 +346,9 @@ class TestUOpGraph(TestUOps):
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, idx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -361,9 +361,9 @@ class TestUOpGraph(TestUOps):
uops = to_uops_list([UOp(UOps.STORE, None, (glbl0, lidx, ld1+ld0))])
ld0, ld1 = uops[-1].src[2].src
# ld0 becomes the invalid value
self.assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
assert_equiv_uops(ld1, UOp.const(dtypes.int, 2))
# the gate and invalid value are deleted from ld1
self.assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -375,7 +375,7 @@ class TestUOpGraph(TestUOps):
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 4)
self.assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
assert_equiv_uops(uops[-1], UOp.store(glbl, idx1, val))
def test_asserts_bad_gate(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
@@ -611,7 +611,7 @@ class TestLoadStoreFolder(unittest.TestCase):
def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + reducer)
class TestIFUOps(TestUOps):
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
@@ -627,7 +627,7 @@ class TestIFUOps(TestUOps):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assert_equiv_uops(if_uops[0].src[0], gate)
assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
@@ -645,7 +645,7 @@ class TestIFUOps(TestUOps):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assert_equiv_uops(if_uops[0].src[0], gate)
assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)
@@ -661,7 +661,7 @@ class TestIFUOps(TestUOps):
sink = gate_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
self.assertEqual(len(if_uops), 1)
self.assert_equiv_uops(if_uops[0].src[0], gate)
assert_equiv_uops(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 3)

View File

@@ -11,7 +11,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite
from tinygrad.shape.symbolic import Variable
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
from test.helpers import is_dtype_supported, assert_equiv_uops
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -356,7 +356,7 @@ class TestUOpCompare(unittest.TestCase):
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
class TestUOpStr(TestEqUOps):
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
for _ in range(20): a = a + a
@@ -368,7 +368,7 @@ class TestUOpStr(TestEqUOps):
# nice big complicated uop
with Context(NOOPT=1):
sink = UOp(UOps.SINK, None, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
self.assert_equiv_uops(sink, eval(str(sink)))
assert_equiv_uops(sink, eval(str(sink)))
def test_nop_str(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, name="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, name="c1")