use is to compare with enum (#3993)

* use is to compare with enum

comparisons are currently a mix of `==` and `is`; this moves them all to `is`

* more
chenyu
2024-03-29 13:02:56 -04:00
committed by GitHub
parent 0affbbf81c
commit d9ff636cf5
13 changed files with 90 additions and 90 deletions
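
For background on the idiom: Python enum members are singletons, so identity comparison with `is` is exact, slightly faster than `==`, and unaffected by `__eq__` overloads (e.g. on mixed-in enums like `IntEnum`). A minimal standalone sketch, using a hypothetical `Ops` enum rather than tinygrad's actual `UOps`:

```python
from enum import Enum, auto

class Ops(Enum):
    LOAD = auto()
    STORE = auto()

op = Ops.LOAD

# each enum member is a singleton, so identity comparison is safe and exact
assert op is Ops.LOAD
assert op is not Ops.STORE

# == also works between members, but comparing a member against a
# non-member silently returns False instead of raising, which can
# mask typos; `is` makes the intent (same member) explicit
assert op == Ops.LOAD
assert (op == 1) is False
```

Note that the diff below only touches enum-member comparisons; value comparisons such as `uop.dtype == dtypes.float.vec(4)` stay with `==`.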

@@ -170,7 +170,7 @@ def fuzz_linearizer(lin: Linearizer):
def _is_simple(lin: Linearizer) -> bool:
if len(lin.ast) > 1: return False
ast:LazyOp = lin.ast[0]
-  if ast.src[0] and ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op == BufferOps.LOAD: return True
+  if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True
return False
if __name__ == "__main__":

@@ -60,7 +60,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
-    num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
+    num_loads = len([uop for uop in k.uops if uop.uop is UOps.LOAD])
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
@@ -93,7 +93,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
-    num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
+    num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
assert num_ops <= 1, "more alu uops than needed"
def test_reduce_upcast(self):
@@ -106,8 +106,8 @@ class TestLinearizer(unittest.TestCase):
k.upcast()
k.upcast()
k.linearize()
-    accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC]
-    stores = [u for u in k.uops if u.uop == UOps.STORE]
+    accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
+    stores = [u for u in k.uops if u.uop is UOps.STORE]
assert len(accs) == 1
assert len(stores) == 1
assert stores[0].vin[-1].dtype == accs[0].dtype == dtypes.float.vec(4)
@@ -122,15 +122,15 @@ class TestLinearizer(unittest.TestCase):
k.hand_coded_optimizations()
k.linearize()
-    accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC]
-    stores = [u for u in k.uops if u.uop == UOps.STORE]
+    accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
+    stores = [u for u in k.uops if u.uop is UOps.STORE]
# the first store is to lds and can be upcasted
assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4)
-    assert stores[0].vin[0].uop == UOps.DEFINE_LOCAL
+    assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL
# the second store is to gds with no upcasts
assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
-    assert stores[1].vin[0].uop == UOps.DEFINE_GLOBAL
+    assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
@@ -139,7 +139,7 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
-    num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
+    num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU])
assert num_ops == 0, "more alu uops than needed"
def test_constant_fold(self):
@@ -157,14 +157,14 @@ class TestLinearizer(unittest.TestCase):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
k.linearize()
-      local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
+      local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
k.linearize()
-      local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
+      local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
tests = (
@@ -194,8 +194,8 @@ class TestLinearizer(unittest.TestCase):
k = Linearizer(realized_ast)
k.apply_tensor_cores(1)
k.linearize()
-      assert len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1, "tensor core not triggered"
-      assert len([x for x in k.applied_opts if x.op == OptOps.TC]) == 1, "tensor core opt not included"
+      assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) == 1, "tensor core not triggered"
+      assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
np_c = np_a @ np_b
(tc_atol, tc_rtol) = (1e-2, 1e-3) if tc.dtype_out == dtypes.half else (5e-3, 1e-4)
np.testing.assert_allclose(np_c, r.numpy(), atol=tc_atol, rtol=tc_rtol)
@@ -213,7 +213,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(*sched[0].ast)
-    assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
+    assert not any(u.uop is UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
a = Tensor.ones(4, 4).contiguous().realize()
@@ -264,8 +264,8 @@ class TestLinearizer(unittest.TestCase):
if if_op:=next((u for u in uops if u.uop is UOps.IF), None):
uops = uops[:uops.index(if_op)]
assert len(set([u.uop for u in uops if u.uop in {UOps.LOOP, UOps.SPECIAL}])) == 1, "has either specials or loops, not both"
-      assert len([u for u in uops if u.uop == UOps.PHI]) == 0, "PHI should have been simplified"
-      assert len([u for u in uops if u.arg == BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
+      assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified"
+      assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
helper(Tensor.arange(5.5, (3.5*300), 3.5))
helper(Tensor.arange(-1, -100, -5))
@@ -285,8 +285,8 @@ def helper_realized_ast(r:Tensor):
class TestFloat4(unittest.TestCase):
@staticmethod
def count_float4(k):
-    return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
-            len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
+    return (len([uop for uop in k.uops if uop.uop is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]),
+            len([uop for uop in k.uops if uop.uop is UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)]))
# TODO: express opts below as auto opts
@@ -831,7 +831,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
# check that the float4 cast collapses
store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE]
for val in store_vals:
-      assert val.dtype == dtypes.float.vec(4) and val.uop != UOps.CAST
+      assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST
@unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4")
def test_grouped_store_values(self):
@@ -843,7 +843,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
k.linearize()
store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0]
-    assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
+    assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST
def test_grouped_store_locals_and_globals(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
@@ -865,7 +865,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
# check that the float4 cast collapses for all stores
for store in local_stores+global_stores:
-      assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop != UOps.CAST
+      assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop is not UOps.CAST
# check the children's vins
assert barrier.vin == tuple(local_stores)
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
@@ -881,11 +881,11 @@ class TestLinearizerUOptimize(unittest.TestCase):
k.hand_coded_optimizations()
k.linearize()
-    stores = [u for u in k.uops if u.uop == UOps.STORE]
+    stores = [u for u in k.uops if u.uop is UOps.STORE]
# the float4 value stores directly in lds and we skip upcast
assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
-    assert stores[0].vin[-1].uop != UOps.CAST
+    assert stores[0].vin[-1].uop is not UOps.CAST
# the global store doesn't change
assert stores[1].vin[-1].dtype == dtypes.float
@@ -903,7 +903,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
for opt in opts: k.apply_opt(opt)
k.linearize()
-    out = [u for u in k.uops if u.uop == UOps.STORE][0]
+    out = [u for u in k.uops if u.uop is UOps.STORE][0]
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4)
def test_skip_unmatching_upcasts_with_gep(self):
@@ -918,7 +918,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
for opt in opts: k.apply_opt(opt)
k.linearize()
-    out = [u for u in k.uops if u.uop == UOps.STORE][0]
+    out = [u for u in k.uops if u.uop is UOps.STORE][0]
assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(2)