mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use is to compare with enum (#3993)
* use is to compare with enum currently it's mixed between `==` and `is`, moved all to `is` * more
This commit is contained in:
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -170,7 +170,7 @@ def fuzz_linearizer(lin: Linearizer):
|
||||
def _is_simple(lin: Linearizer) -> bool:
|
||||
if len(lin.ast) > 1: return False
|
||||
ast:LazyOp = lin.ast[0]
|
||||
if ast.src[0] and ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op == BufferOps.LOAD: return True
|
||||
if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -60,7 +60,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
|
||||
num_loads = len([uop for uop in k.uops if uop.uop is UOps.LOAD])
|
||||
assert num_loads <= 4, "more load uops than needed"
|
||||
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
|
||||
num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
|
||||
assert num_ops <= 1, "more alu uops than needed"
|
||||
|
||||
def test_reduce_upcast(self):
|
||||
@@ -106,8 +106,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
k.upcast()
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC]
|
||||
stores = [u for u in k.uops if u.uop == UOps.STORE]
|
||||
accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
|
||||
stores = [u for u in k.uops if u.uop is UOps.STORE]
|
||||
assert len(accs) == 1
|
||||
assert len(stores) == 1
|
||||
assert stores[0].vin[-1].dtype == accs[0].dtype == dtypes.float.vec(4)
|
||||
@@ -122,15 +122,15 @@ class TestLinearizer(unittest.TestCase):
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
|
||||
accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC]
|
||||
stores = [u for u in k.uops if u.uop == UOps.STORE]
|
||||
accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
|
||||
stores = [u for u in k.uops if u.uop is UOps.STORE]
|
||||
|
||||
# the first store is to lds and can be upcasted
|
||||
assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4)
|
||||
assert stores[0].vin[0].uop == UOps.DEFINE_LOCAL
|
||||
assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL
|
||||
# the second store is to gds with no upcasts
|
||||
assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
|
||||
assert stores[1].vin[0].uop == UOps.DEFINE_GLOBAL
|
||||
assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
|
||||
|
||||
def test_zero_fold(self):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
@@ -139,7 +139,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
|
||||
num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
|
||||
assert num_ops == 0, "more alu uops than needed"
|
||||
|
||||
def test_constant_fold(self):
|
||||
@@ -157,14 +157,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
|
||||
k.linearize()
|
||||
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
|
||||
local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
|
||||
assert local[0].dtype == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
|
||||
k.linearize()
|
||||
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
|
||||
local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
|
||||
assert local[0].dtype == expected_dtype
|
||||
|
||||
tests = (
|
||||
@@ -194,8 +194,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
k = Linearizer(realized_ast)
|
||||
k.apply_tensor_cores(1)
|
||||
k.linearize()
|
||||
assert len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1, "tensor core not triggered"
|
||||
assert len([x for x in k.applied_opts if x.op == OptOps.TC]) == 1, "tensor core opt not included"
|
||||
assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) == 1, "tensor core not triggered"
|
||||
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
np_c = np_a @ np_b
|
||||
(tc_atol, tc_rtol) = (1e-2, 1e-3) if tc.dtype_out == dtypes.half else (5e-3, 1e-4)
|
||||
np.testing.assert_allclose(np_c, r.numpy(), atol=tc_atol, rtol=tc_rtol)
|
||||
@@ -213,7 +213,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
assert len(sched) == 1
|
||||
lin = Linearizer(*sched[0].ast)
|
||||
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
assert not any(u.uop is UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
|
||||
def test_assign_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
@@ -264,8 +264,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
if if_op:=next((u for u in uops if u.uop is UOps.IF), None):
|
||||
uops = uops[:uops.index(if_op)]
|
||||
assert len(set([u.uop for u in uops if u.uop in {UOps.LOOP, UOps.SPECIAL}])) == 1, "has either specials or loops, not both"
|
||||
assert len([u for u in uops if u.uop == UOps.PHI]) == 0, "PHI should have been simplified"
|
||||
assert len([u for u in uops if u.arg == BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
|
||||
assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified"
|
||||
assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
|
||||
|
||||
helper(Tensor.arange(5.5, (3.5*300), 3.5))
|
||||
helper(Tensor.arange(-1, -100, -5))
|
||||
@@ -285,8 +285,8 @@ def helper_realized_ast(r:Tensor):
|
||||
class TestFloat4(unittest.TestCase):
|
||||
@staticmethod
|
||||
def count_float4(k):
|
||||
return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
|
||||
len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
|
||||
return (len([uop for uop in k.uops if uop.uop is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
|
||||
len([uop for uop in k.uops if uop.uop is UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
|
||||
|
||||
# TODO: express opts below as auto opts
|
||||
|
||||
@@ -831,7 +831,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
# check that the float4 cast collapses
|
||||
store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
|
||||
for val in store_vals:
|
||||
assert val.dtype == dtypes.float.vec(4) and val.uop != UOps.CAST
|
||||
assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4")
|
||||
def test_grouped_store_values(self):
|
||||
@@ -843,7 +843,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
k.linearize()
|
||||
|
||||
store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
|
||||
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST
|
||||
|
||||
def test_grouped_store_locals_and_globals(self):
|
||||
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
|
||||
@@ -865,7 +865,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
|
||||
# check that the float4 cast collapses for all stores
|
||||
for store in local_stores+global_stores:
|
||||
assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop != UOps.CAST
|
||||
assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop is not UOps.CAST
|
||||
# check the children's vins
|
||||
assert barrier.vin == tuple(local_stores)
|
||||
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
|
||||
@@ -881,11 +881,11 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
|
||||
stores = [u for u in k.uops if u.uop == UOps.STORE]
|
||||
stores = [u for u in k.uops if u.uop is UOps.STORE]
|
||||
|
||||
# the float4 value stores directly in lds and we skip upcast
|
||||
assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
|
||||
assert stores[0].vin[-1].uop != UOps.CAST
|
||||
assert stores[0].vin[-1].uop is not UOps.CAST
|
||||
|
||||
# the global store doesn't change
|
||||
assert stores[1].vin[-1].dtype == dtypes.float
|
||||
@@ -903,7 +903,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
k.linearize()
|
||||
|
||||
out = [u for u in k.uops if u.uop == UOps.STORE][0]
|
||||
out = [u for u in k.uops if u.uop is UOps.STORE][0]
|
||||
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
|
||||
|
||||
def test_skip_unmatching_upcasts_with_gep(self):
|
||||
@@ -918,7 +918,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
k.linearize()
|
||||
|
||||
out = [u for u in k.uops if u.uop == UOps.STORE][0]
|
||||
out = [u for u in k.uops if u.uop is UOps.STORE][0]
|
||||
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(2)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user