From 10305bfc0a0b59b74e34bf2ac6f47de956ae7622 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 4 Sep 2023 16:35:11 -0700
Subject: [PATCH] tuples only (#1769)

---
 tinygrad/codegen/linearizer.py | 56 +++++++++++++++++-----------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index afff43bd29..1bb01290d2 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -113,26 +113,26 @@ class Linearizer(OptimizedKernel):
       if key not in self.load_cache:
         if acc is not None:
           assert valid.min == 1
-          self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
+          self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const)
         elif this_const is not None:
           self.load_cache[key] = self.const(this_const, localtype)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)], TernaryOps.WHERE, cachable=True)
+            self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE, cachable=True)
         else:
           buf_uop = self.buf_uops[i]
           assert buf_uop is not None, f"buffer {i} wasn't UOped"
           if isinstance(self.bufs[i].dtype, ImageDType):
             idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
-            rendered_idx = self.uop(UOps.CAST, dtypes._int2, [idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)])
+            rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
           else:
             rendered_idx = idx.render(self.render_ops, self)
           if valid.min == 0:
             valid_rendered = valid.render(self.render_ops, self)
-            self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)])
+            self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
           else:
-            self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx])
-      ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
+            self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
+      ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
     return ret
 
   def global_store(self, i:int, idxs:List[VariableOrNum], store:List[UOp]) -> None:
@@ -156,17 +156,17 @@ class Linearizer(OptimizedKernel):
         idx, valid = self.sts[i].expr_idxs(k)
         assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
         assert valid.min == 1, "stores are always valid"
-        store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, out_tokens)
+        store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
       store_offset = store_offset_new
 
     for idx, var in store_offset.items():
       idx, valid = self.sts[i].expr_idxs(idx)
       if isinstance(self.bufs[i].dtype, ImageDType):
         idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
-        rendered_idx = self.uop(UOps.CAST, dtypes._int2, [idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)])
+        rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
       else:
         rendered_idx = idx.render(self.render_ops, self)
-      self.uop(UOps.STORE, None, [buf_uop, rendered_idx, var])
+      self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))
 
   kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
   def linearize(self):
@@ -187,22 +187,22 @@ class Linearizer(OptimizedKernel):
     # add global buffers
     arg_bufs = {}
     for buf,name in self.arg_bufs.items():
-      arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, [], (name, buf.dtype))
+      arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (name, buf.dtype))
     for i,b in enumerate(self.bufs):
       if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
    # add variables from symbolic shapes
     for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
       assert var.expr is not None
-      self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
+      self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
    # define local buffers
     for lb in self.local_alias.values():
-      self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], (lb.name, self.sts[self.bufs.index(lb)].size()))
+      self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
    # add a local buffer for multistage reduce. # TODO: use local alias
     if self.group_for_reduce:
       # TODO: the strides of this can be controlled
       self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
       self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
-      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], ("temp", self.sts[-1].size())))
+      self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))
 
     # print
     if DEBUG >= 3: self.printbufs()
@@ -231,7 +231,7 @@ class Linearizer(OptimizedKernel):
       for x in xx[::-1]:
         if not isinstance(x, NumNode) and x.expr is not None:
           loop_uop = self.loop_uops[x.expr]
-          if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
+          if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))
 
     if self.opts.has_local:
       self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
@@ -259,7 +259,7 @@ class Linearizer(OptimizedKernel):
       render_loop(reduce_idxs)
 
       # barrier for fast GEMM
-      if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
+      if self.use_tensor_cores: self.uop(UOps.BARRIER, None, ())
 
      # compute local aliases
       # TODO: this is garbage code and should be at least moved elsewhere
@@ -305,24 +305,24 @@ class Linearizer(OptimizedKernel):
             i = 0
             for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
               for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
-                self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
+                self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
                 i += 2
           else:
             k = len(locals_to_store[1][2]) // 2
             for i in range(0, len(acc), 2):
               for y0,y1,x0,x1 in zip(locals_to_store[1][2][:k], locals_to_store[1][2][k:], locals_to_store[0][2][k*i:], locals_to_store[0][2][k*i+k:]):
-                self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
+                self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
         elif self.bufs[0].device == "HIP":
           i = 0
           for y in range(0, len(locals_to_store[1][2]), 0x10):
             for x in range(0, len(locals_to_store[0][2]), 0x10):
-              self.uop(UOps.WMMA, None, acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10], "HIP")
+              self.uop(UOps.WMMA, None, tuple(acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]), "HIP")
              i += 8
       else:
         if locals_to_store:
-          self.uop(UOps.BARRIER, None, [], ())
+          self.uop(UOps.BARRIER, None, ())
           for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
-          self.uop(UOps.BARRIER, None, [], ())
+          self.uop(UOps.BARRIER, None, ())
 
        # load earlybufs
         loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
@@ -338,7 +338,7 @@ class Linearizer(OptimizedKernel):
      if self.group_for_reduce:
        fake_global_idxs = [x*0 for x in global_idxs]
        self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc)  # store accumulators
-       self.uop(UOps.BARRIER, None, [], ())
+       self.uop(UOps.BARRIER, None, ())
        end_loop(loop_local_idxs)
 
        # local indexs are over, 0 them out
@@ -398,14 +398,14 @@ class Linearizer(OptimizedKernel):
 
     return self
 
-  def uop(self, uop:UOps, dtype:Optional[DType], vin:Union[Tuple[UOp, ...], List[UOp]], arg:Any=None, cachable=False) -> UOp:
-    key = (uop, dtype, tuple(vin), arg)
+  def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=False) -> UOp:
+    key = (uop, dtype, vin, arg)
     if uop == UOps.STORE and len(vin) == 2 and vin[0] == vin[1]: return vin[0]  # self store is noop
     if uop == UOps.ALU:
       # rewrites. NOTE: the rewritten NEG op is still around...
-      if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, [vin[0], vin[1].vin[0]], BinaryOps.SUB, cachable=cachable)
+      if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
       # constant folding
-      if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, [], -vin[0].arg, cachable=True)
+      if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, (), -vin[0].arg, cachable=True)
       # zero folding
       for x in [0,1]:
         if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
@@ -414,7 +414,7 @@ class Linearizer(OptimizedKernel):
       if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
       if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
     if cachable and key in self.saved_exprs: return self.saved_exprs[key]
-    self.uops.append(UOp(uop, dtype, tuple(vin), arg, len(self.uops)))
+    self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
     if DEBUG >= 4: print(self.uops[-1])
     if cachable: self.saved_exprs[key] = self.uops[-1]
     return self.uops[-1]
@@ -431,9 +431,9 @@ class Linearizer(OptimizedKernel):
     values = [self.ast_parse(v, acc, loaded_buffers) for v in x.src]
     ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
     if x.op in ops:
-      ret = [(idx, self.uop(UOps.STORE, dtypes.float32, [val[-1], self.uop(UOps.ALU, dtypes.float32, list(val), ops[x.op])])) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values, acc))]
+      ret = [(idx, self.uop(UOps.STORE, dtypes.float32, (val[-1], self.uop(UOps.ALU, dtypes.float32, val, ops[x.op])))) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values, acc))]
     else:
-      ret = [(idx, self.uop(UOps.ALU, dtypes.float32, list(val), x.op, cachable=True)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
+      ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op, cachable=True)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
     ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
     # scatter
     for i,j in ret:
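Why tuples: the uop cache key is (uop, dtype, vin, arg) and is looked up in self.saved_exprs, so vin has to be hashable; with tuples passed at every call site the key can use vin directly instead of copying it with tuple(vin). Below is a minimal standalone sketch of that caching pattern, with hypothetical names (cached_uop, saved_exprs as a module-level dict), not the tinygrad API:

from typing import Any, Dict, Tuple

# hypothetical memoization table keyed on (op, inputs, arg), mirroring saved_exprs in the patch
saved_exprs: Dict[Tuple[Any, ...], str] = {}

def cached_uop(uop: str, vin: Tuple[Any, ...], arg: Any = None) -> str:
  # tuples hash directly; a list vin would raise "TypeError: unhashable type: 'list'" here
  key = (uop, vin, arg)
  if key in saved_exprs: return saved_exprs[key]
  saved_exprs[key] = f"{uop}({', '.join(map(str, vin))})"
  return saved_exprs[key]

# same arguments return the same cached object, with no tuple(vin) copy on every call
assert cached_uop("ALU", ("a", "b")) is cached_uop("ALU", ("a", "b"))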