From d9ff636cf5130e47d6471e20d420133f418bf54e Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 29 Mar 2024 13:02:56 -0400 Subject: [PATCH] use is to compare with enum (#3993) * use is to compare with enum currently it's mixed between `==` and `is`, moved all to `is` * more --- test/external/fuzz_linearizer.py | 2 +- test/test_linearizer.py | 50 ++++++++++++++++---------------- tinygrad/codegen/kernel.py | 38 ++++++++++++------------ tinygrad/codegen/linearizer.py | 8 ++--- tinygrad/codegen/uops.py | 16 +++++----- tinygrad/engine/schedule.py | 2 +- tinygrad/features/graph.py | 4 +-- tinygrad/lazy.py | 4 +-- tinygrad/ops.py | 2 +- tinygrad/renderer/assembly.py | 34 +++++++++++----------- tinygrad/renderer/cstyle.py | 10 +++---- tinygrad/renderer/llvmir.py | 4 +-- tinygrad/runtime/ops_disk.py | 6 ++-- 13 files changed, 90 insertions(+), 90 deletions(-) diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 11f9b24ed0..80fed448de 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -170,7 +170,7 @@ def fuzz_linearizer(lin: Linearizer): def _is_simple(lin: Linearizer) -> bool: if len(lin.ast) > 1: return False ast:LazyOp = lin.ast[0] - if ast.src[0] and ast.src[0].op == UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op == BufferOps.LOAD: return True + if ast.src[0] and ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0] and ast.src[0].src[0].op is BufferOps.LOAD: return True return False if __name__ == "__main__": diff --git a/test/test_linearizer.py b/test/test_linearizer.py index cd9dc70019..7d38f7dbd3 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -60,7 +60,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD]) + num_loads = len([uop for uop in k.uops if uop.uop is UOps.LOAD]) assert num_loads <= 4, "more load uops than needed" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" @@ -93,7 +93,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) + num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU]) assert num_ops <= 1, "more alu uops than needed" def test_reduce_upcast(self): @@ -106,8 +106,8 @@ class TestLinearizer(unittest.TestCase): k.upcast() k.upcast() k.linearize() - accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC] - stores = [u for u in k.uops if u.uop == UOps.STORE] + accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC] + stores = [u for u in k.uops if u.uop is UOps.STORE] assert len(accs) == 1 assert len(stores) == 1 assert stores[0].vin[-1].dtype == accs[0].dtype == dtypes.float.vec(4) @@ -122,15 +122,15 @@ class TestLinearizer(unittest.TestCase): k.hand_coded_optimizations() k.linearize() - accs = [u for u in k.uops if u.uop == UOps.DEFINE_ACC] - stores = [u for u in k.uops if u.uop == UOps.STORE] + accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC] + stores = [u for u in k.uops if u.uop is UOps.STORE] # the first store is to lds and can be upcasted assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4) - assert stores[0].vin[0].uop == UOps.DEFINE_LOCAL + assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL # the second store is to gds with no upcasts assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float - assert stores[1].vin[0].uop == UOps.DEFINE_GLOBAL + assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL def test_zero_fold(self): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() @@ -139,7 +139,7 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(*create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU]) + num_ops = len([uop for uop in k.uops if uop.uop is UOps.ALU]) assert num_ops == 0, "more alu uops than needed" def test_constant_fold(self): @@ -157,14 +157,14 @@ class TestLinearizer(unittest.TestCase): a = Tensor([1, 2, 3], dtype=tensor_dtype).sum() k = Linearizer(*create_schedule([a.lazydata])[-1].ast) k.linearize() - local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC] + local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC] assert local[0].dtype == acc_dtype def test_arg_acc_dtype(self): def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType): k = Linearizer(*create_schedule([c.lazydata])[-1].ast) k.linearize() - local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC] + local = [uop for uop in k.uops if uop.uop is UOps.DEFINE_ACC] assert local[0].dtype == expected_dtype tests = ( @@ -194,8 +194,8 @@ class TestLinearizer(unittest.TestCase): k = Linearizer(realized_ast) k.apply_tensor_cores(1) k.linearize() - assert len([uop for uop in k.uops if uop.uop == UOps.WMMA]) == 1, "tensor core not triggered" - assert len([x for x in k.applied_opts if x.op == OptOps.TC]) == 1, "tensor core opt not included" + assert len([uop for uop in k.uops if uop.uop is UOps.WMMA]) == 1, "tensor core not triggered" + assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included" np_c = np_a @ np_b (tc_atol, tc_rtol) = (1e-2, 1e-3) if tc.dtype_out == dtypes.half else (5e-3, 1e-4) np.testing.assert_allclose(np_c, r.numpy(), atol=tc_atol, rtol=tc_rtol) @@ -213,7 +213,7 @@ class TestLinearizer(unittest.TestCase): sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] assert len(sched) == 1 lin = Linearizer(*sched[0].ast) - assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse" + assert not any(u.uop is UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse" def test_assign_fold(self): a = Tensor.ones(4, 4).contiguous().realize() @@ -264,8 +264,8 @@ class TestLinearizer(unittest.TestCase): if if_op:=next((u for u in uops if u.uop is UOps.IF), None): uops = uops[:uops.index(if_op)] assert len(set([u.uop for u in uops if u.uop in {UOps.LOOP, UOps.SPECIAL}])) == 1, "has either specials or loops, not both" - assert len([u for u in uops if u.uop == UOps.PHI]) == 0, "PHI should have been simplified" - assert len([u for u in uops if u.arg == BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" + assert len([u for u in uops if u.uop is UOps.PHI]) == 0, "PHI should have been simplified" + assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" helper(Tensor.arange(5.5, (3.5*300), 3.5)) helper(Tensor.arange(-1, -100, -5)) @@ -285,8 +285,8 @@ def helper_realized_ast(r:Tensor): class TestFloat4(unittest.TestCase): @staticmethod def count_float4(k): - return (len([uop for uop in k.uops if uop.uop == UOps.LOAD and uop.dtype == dtypes.float.vec(4)]), - len([uop for uop in k.uops if uop.uop == UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)])) + return (len([uop for uop in k.uops if uop.uop is UOps.LOAD and uop.dtype == dtypes.float.vec(4)]), + len([uop for uop in k.uops if uop.uop is UOps.STORE and len(uop.vin) == 3 and uop.vin[2].dtype == dtypes.float.vec(4)])) # TODO: express opts below as auto opts @@ -831,7 +831,7 @@ class TestLinearizerUOptimize(unittest.TestCase): # check that the float4 cast collapses store_vals = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE] for val in store_vals: - assert val.dtype == dtypes.float.vec(4) and val.uop != UOps.CAST + assert val.dtype == dtypes.float.vec(4) and val.uop is not UOps.CAST @unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "device doesn't support float4") def test_grouped_store_values(self): @@ -843,7 +843,7 @@ class TestLinearizerUOptimize(unittest.TestCase): k.linearize() store_val = [u.vin[-1] for u in k.uops if u.uop is UOps.STORE][0] - assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST + assert store_val.dtype == dtypes.float.vec(4) and store_val.uop is not UOps.CAST def test_grouped_store_locals_and_globals(self): if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \ @@ -865,7 +865,7 @@ class TestLinearizerUOptimize(unittest.TestCase): # check that the float4 cast collapses for all stores for store in local_stores+global_stores: - assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop != UOps.CAST + assert store.vin[-1].dtype == dtypes.float.vec(2) and store.vin[-1].uop is not UOps.CAST # check the children's vins assert barrier.vin == tuple(local_stores) assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1 @@ -881,11 +881,11 @@ class TestLinearizerUOptimize(unittest.TestCase): k.hand_coded_optimizations() k.linearize() - stores = [u for u in k.uops if u.uop == UOps.STORE] + stores = [u for u in k.uops if u.uop is UOps.STORE] # the float4 value stores directly in lds and we skip upcast assert stores[0].vin[-1].dtype == dtypes.float.vec(4) - assert stores[0].vin[-1].uop != UOps.CAST + assert stores[0].vin[-1].uop is not UOps.CAST # the global store doesn't change assert stores[1].vin[-1].dtype == dtypes.float @@ -903,7 +903,7 @@ class TestLinearizerUOptimize(unittest.TestCase): for opt in opts: k.apply_opt(opt) k.linearize() - out = [u for u in k.uops if u.uop == UOps.STORE][0] + out = [u for u in k.uops if u.uop is UOps.STORE][0] assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(4) def test_skip_unmatching_upcasts_with_gep(self): @@ -918,7 +918,7 @@ class TestLinearizerUOptimize(unittest.TestCase): for opt in opts: k.apply_opt(opt) k.linearize() - out = [u for u in k.uops if u.uop == UOps.STORE][0] + out = [u for u in k.uops if u.uop is UOps.STORE][0] assert out.vin[-1].uop is UOps.CAST and out.vin[-1].dtype == dtypes.float.vec(2) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 4e58536889..1c45112213 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -332,19 +332,19 @@ class Kernel: # ******************** high level optimizers ******************** def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool: - if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores: + if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores: for tc in tensor_cores[self.opts.device]: has_cast = tc.dtype_in != tc.dtype_out - if has_cast and not(self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue + if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0] - if mul_op.op != BinaryOps.MUL: continue + if mul_op.op is not BinaryOps.MUL: continue def buf_index(src: LazyOp) -> Optional[int]: # TODO: apply tc even if the sources are not from LOAD - if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg)) + if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg)) try: - if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg)) + if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg)) except ValueError: return None return None if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue @@ -405,7 +405,7 @@ class Kernel: def apply_opt(self, opt:Opt, append_opt:bool=True): check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals") - if opt.op == OptOps.TC: + if opt.op is OptOps.TC: check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt") check((use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2") @@ -414,34 +414,34 @@ class Kernel: return if opt.axis is not None: - axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501 + axis = opt.axis + (self.first_reduce if opt.op is OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501 else: axis = -1 check(axis < len(self.full_shape), "invalid axis") if opt.amt is not None: amt = opt.amt if opt.amt != 0 else self.full_shape[axis] check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless") - if opt.op != OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift") + if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift") else: amt = -1 - if self.reduceop and (opt.op in [OptOps.GROUP, OptOps.GROUPTOP] or (self.group_for_reduces and opt.op not in [OptOps.NOLOCALS, OptOps.PADTO])): + if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): acc_sz = dt.base.itemsize if isinstance((dt:=get_lazyop_info(self.reduceop).dtype), ImageDType) else dt.itemsize upcast_sz = prod(self.full_shape[self.shape_len-self.upcasted:]) local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces]) check(amt*acc_sz*upcast_sz*local_sz <= self.opts.shared_max, "exceeds maximum shared memory size") - if opt.op == OptOps.LOCAL: # cyan + if opt.op is OptOps.LOCAL: # cyan check(self.opts.has_local, "target does not support local") check(axis < self.global_dims, "local is for globals") self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims) self.local_dims += 1 - elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]: # green + elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem") check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group") check(not self.tensor_core, "can't group with tensor cores") - self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces) + self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces) self.group_for_reduces += 1 - elif opt.op == OptOps.UNROLL: # purple + elif opt.op is OptOps.UNROLL: # purple check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted") check(amt <= 32, "don't unroll more than 32") # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores @@ -451,13 +451,13 @@ class Kernel: if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP self.shift_to(axis, amt, insert_before=None) self.upcast() - elif opt.op == OptOps.UPCAST: # yellow + elif opt.op is OptOps.UPCAST: # yellow check(axis < self.first_reduce, "upcast is for non-reduce") check(not(self.tensor_core and axis >= self.first_reduce-len(self.tensor_core.threads)), "can't upcast TC locals") check(amt <= 8, "don't upcast more than 8") self.shift_to(axis, amt, insert_before=None) self.upcast() - elif opt.op == OptOps.UPCASTMID: # white + elif opt.op is OptOps.UPCASTMID: # white check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501 axes = self.sts[0].unit_stride_axes() check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}") @@ -465,11 +465,11 @@ class Kernel: check(amt == 4, "don't upcast mid anything but 4") self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces) self.group_for_reduces += 1 - elif opt.op == OptOps.NOLOCALS: + elif opt.op is OptOps.NOLOCALS: check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals") check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals") self.dont_use_locals = True - elif opt.op == OptOps.PADTO: + elif opt.op is OptOps.PADTO: check(not self.vars, "does not work with symbolic shape") check(axis < self.first_reduce, "cannot pad a reduce axis") padded = False @@ -498,8 +498,8 @@ class Kernel: # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ - self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \ - (mulop:=self.reduceop.src[0]).op == BinaryOps.MUL and mulop.src[0].op == BufferOps.LOAD and mulop.src[1].op == BufferOps.LOAD: + self.reduceop and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \ + (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD: st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)] strides0, strides1 = st0.real_strides(), st1.real_strides() def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides)) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 7f5b0baea4..72b44912b8 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -53,8 +53,8 @@ class Linearizer(Kernel): info = get_lazyop_info(reduceop) assert all(0 <= x < len(info.shape) for x in reduceop.arg), "arg axis out of range" dtype = info.dtype - if reduceop.op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0 - elif reduceop.op == ReduceOps.MAX: + if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0 + elif reduceop.op is ReduceOps.MAX: if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1) return -math.inf if dtypes.is_float(dtype) else False @@ -410,7 +410,7 @@ class Linearizer(Kernel): if cache is None: cache = {} if x in cache: return cache[x] if x.op in BufferOps: return loaded_buffers[x.arg] - if x.op == UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \ + if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \ for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)] if x.op in ReduceOps and not do_reduce: assert offs is None, "not available if we aren't doing reduce" @@ -428,6 +428,6 @@ class Linearizer(Kernel): if input_acc[off] != acc[off]: acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx)) else: - ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)] + ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)] cache[x] = ret return ret diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index f36ed79d76..d1da750417 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -57,8 +57,8 @@ def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p)) def uop_alu_resolve(u:UOp) -> sint: if u.uop is UOps.CONST: return u.arg elif u.uop is UOps.DEFINE_VAR: return u.arg - elif u.uop is UOps.ALU and u.arg == BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1]) - elif u.uop is UOps.ALU and u.arg == BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1]) + elif u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1]) + elif u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1]) else: raise RuntimeError(f"ALU resolve fail @ {u.uop}") def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool: @@ -198,7 +198,7 @@ class UOpGraph: while ssize != len(deps): ssize = len(deps) for u in self.uops: - if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])): + if len(deps.intersection([x for x in u.vin if x.uop is not UOps.PHI])): deps.add(u) return deps @@ -237,16 +237,16 @@ class UOpGraph: def simplify_phi_loops(self, get_recursive_parents): def alu_opposite(arg, x, y): - if arg == BinaryOps.ADD: return x - y - elif arg == BinaryOps.MUL: return Node.__floordiv__(x, y, False) + if arg is BinaryOps.ADD: return x - y + elif arg is BinaryOps.MUL: return Node.__floordiv__(x, y, False) else: raise RuntimeError("unhandled alu") def to_symbolic(u: UOp): - if u.uop == UOps.CONST: return NumNode(int(u.arg)) + if u.uop is UOps.CONST: return NumNode(int(u.arg)) elif u.uop in {UOps.LOOP, UOps.SPECIAL}: if u not in seen_vars: seen_vars[u] = u.arg[1] if u.uop is UOps.SPECIAL else "loop{}".format(len(seen_vars)) return Variable(seen_vars[u], u.vin[0].arg, u.vin[1].arg-1) if u.uop is UOps.LOOP else Variable(seen_vars[u], 0, u.arg[2]-1) - elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD: return to_symbolic(u.vin[0]) + to_symbolic(u.vin[1]) - elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL: return to_symbolic(u.vin[0]) * to_symbolic(u.vin[1]) + elif u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return to_symbolic(u.vin[0]) + to_symbolic(u.vin[1]) + elif u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return to_symbolic(u.vin[0]) * to_symbolic(u.vin[1]) else: raise RuntimeError("unhandled op: {}".format(u)) def loop_factor(with_loop: UOp, factored: Node, loop_op, round_up=False): if with_loop == loop_op: return factored diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9695e96d16..508ef9a5cc 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -100,7 +100,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe if buf.forced_realize: realizes.add(buf) allbufs[buf] = None if buf.op in LoadOps: realizes.add(buf.base) - if buf.op == LoadOps.COPY: + if buf.op is LoadOps.COPY: assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig" realizes.add(buf.srcs[0].base) for x in buf.srcs: diff --git a/tinygrad/features/graph.py b/tinygrad/features/graph.py index 330466ccaf..1fe8a367f1 100644 --- a/tinygrad/features/graph.py +++ b/tinygrad/features/graph.py @@ -49,7 +49,7 @@ top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", Bin TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'} def log_lazybuffer(lb:'LazyBuffer', scheduled=False): init_graph() - if lb.base.realized is None and lb.base.op == LoadOps.CONST: return + if lb.base.realized is None and lb.base.op is LoadOps.CONST: return if lb.base != lb: offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0] label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "") @@ -60,7 +60,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False): label_append = [] for idx,x in enumerate(lb.srcs): if nm(x) not in G.nodes: log_lazybuffer(x) - if x.base.realized is None and x.base.op == LoadOps.CONST: + if x.base.realized is None and x.base.op is LoadOps.CONST: label_append.append(f"\nCONST{idx} {x.base.arg}") else: G.add_edge(nm(x), nm(lb), color='#a0a0a0') diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index f76971a1a3..bbdf6e31cf 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -92,8 +92,8 @@ class LazyBuffer: new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,) return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,)) - def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST - def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST + def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST + def is_unrealized_contiguous_const(self): return self.base == self and self.base.realized is None and self.op is LoadOps.CONST def _copy(self, device:str) -> LazyBuffer: if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index d2263d8d4b..5cecbbc52e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -89,7 +89,7 @@ InterpretedFlopCounter: Dict[Op, Callable] = { BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}), BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501 UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops - **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501 + **{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST}, # noqa: E501 **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501 **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501 TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501 diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index 65d0b11e8a..2a224954af 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -66,10 +66,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: def ptr_ar(root): assert root.arg in {'.shared', '.global', None} - if root.arg is None: root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL + if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root)) ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root)) - if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:] + if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:] else: zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root)) bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root)) @@ -133,17 +133,17 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: for u in uops: uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg - if uop == UOps.IF: + if uop is UOps.IF: assert vin[0].dtype is not None kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:") - elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier) - elif uop == UOps.ENDLOOP: + elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier) + elif uop is UOps.ENDLOOP: kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]), lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int])) kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:") - elif uop == UOps.ENDIF: + elif uop is UOps.ENDIF: kk(f"{r_label[vin[0]]}:") - elif uop == UOps.STORE: + elif uop is UOps.STORE: assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None if vin[2].dtype.count > 1: kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \ @@ -152,8 +152,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg)) else: assert dtype is not None, f"None dtype for uop {uop}" - if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop'))) - elif uop == UOps.ALU: + if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop'))) + elif uop is UOps.ALU: assert vin[0].dtype is not None operands = [r[x] for x in vin] lab = ssa(u, "alu") @@ -163,28 +163,28 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: for i, op in enumerate(operands): operands[i] = ssa(None, "alu_cast", lang.types[dtype]) kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore - if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ: + if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ: # pass in the other dtype here kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype])) else: kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype])) if needs_upcast: kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype)) - elif uop == UOps.DEFINE_ACC: + elif uop is UOps.DEFINE_ACC: if dtype.count > 1: r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)] for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};") else: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};") - elif uop == UOps.SPECIAL: + elif uop is UOps.SPECIAL: assert args[1][0] != "i", "idx not supported" kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};") r[u] = "%" + args[1] kernel = [f".reg .u32 %{args[1]};"] + kernel - elif uop == UOps.CONST: + elif uop is UOps.CONST: if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)] else: r[u] = const(args, dtype, mov=True) - elif uop == UOps.GEP: r[u] = r[vin[0]][u.arg] - elif uop == UOps.LOAD: + elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg] + elif uop is UOps.LOAD: assert vin[1].dtype is not None if dtype.count > 1: r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)] @@ -195,14 +195,14 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str: else: kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None, alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg)) - elif uop == UOps.PHI: + elif uop is UOps.PHI: kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};") r[u] = r[vin[0]] elif uop in {UOps.CAST, UOps.BITCAST}: assert vin[0].dtype is not None if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore else: cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u) - elif uop == UOps.DEFINE_LOCAL: + elif uop is UOps.DEFINE_LOCAL: # TODO: we should sum these, and fetch 0xC000 from somewhere assert args[1]*dtype.itemsize <= 0xC000, "too large local" kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype)) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 2a4d78d1d0..6fbd58124b 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -131,7 +131,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str val = lang.code_for_op[args](*operands, dtype) assert child_count[u] != 0, f"childless ALU op found {u}" # TODO: fix index rendering issue. fix clang nested max macro issue - if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val + if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'alu')} = {val};") elif uop is UOps.SPECIAL: kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */") @@ -215,7 +215,7 @@ class MetalLanguage(CStyleLanguage): return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype) def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): - prefix, wmma_args = ["#include ","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA]) + prefix, wmma_args = ["#include ","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA]) for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{ simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c); @@ -257,7 +257,7 @@ class CUDALanguage(CStyleLanguage): prefix += ["#include "] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]] # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing - for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): + for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]): fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1] prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b); asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};" @@ -338,7 +338,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { prefix += [_make_hip_dtype(*x) for x in vec_dts] - for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper + for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32") else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) { half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; } @@ -347,7 +347,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) { return super().render_kernel(function_name, kernel, bufs, uops, prefix) def get_kernel_modifier(self, uops:UOpGraph) -> str: - requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop == UOps.SPECIAL and u.arg[1][0] == "l") + requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l") # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))" diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 79a3fb51fe..c07c4f8a73 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -109,7 +109,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str: bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block) else: assert dtype is not None, f"None dtype for uop {uop}" - if uop == UOps.LOOP: + if uop is UOps.LOOP: bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}"))) bb[-2].branch(bb[-1].block) @@ -138,7 +138,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str: lvars[u] = lvars[vin[1]] # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC backward = vin[0] - while backward.uop == UOps.PHI: backward = backward.vin[0] + while backward.uop is UOps.PHI: backward = backward.vin[0] lvars[backward] = lvars[u] elif uop is UOps.ALU: lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype) diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 9691aba5a3..8909b92dfc 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -48,17 +48,17 @@ class DiskRunner(JITRunner): skip_allocation = True def __init__(self, ast:LazyOp): # two ASTs are allowed here. - assert ast.op == BufferOps.STORE, "output of AST must be store" + assert ast.op is BufferOps.STORE, "output of AST must be store" assert ast.arg.st.contiguous, "shapetracker must be contiguous" # TODO: there shouldn't actually be casts here, bitcasts should fold into the load - if ast.src[0].op == UnaryOps.CAST: + if ast.src[0].op is UnaryOps.CAST: top_src = ast.src[0].src[0] assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts" self.new_dtype = ast.src[0].arg[0] else: top_src = ast.src[0] self.new_dtype = top_src.arg.dtype - assert top_src.op == BufferOps.LOAD, "top of AST must be load" + assert top_src.op is BufferOps.LOAD, "top of AST must be load" assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view" view = top_src.arg.st.views[0] assert view.mask is None, "view cannot have a mask"