tuples only (#1769)
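In this change (as the hunks below show), every operand list passed to Linearizer.uop becomes a tuple, and the uop signature is narrowed from Union[Tuple[UOp, ...], List[UOp]] to Tuple[UOp, ...]. The benefit shows up in uop itself: the cache key (uop, dtype, vin, arg) no longer needs a tuple(vin) copy and the UOp can hold vin directly, since tuples are hashable and immutable while lists are neither. A minimal standalone sketch of the hashability point (illustrative only; build and saved are made-up names, not tinygrad code):

# Minimal standalone sketch (not tinygrad code): a tuple-valued operand field can
# sit inside a dict key, a list-valued one cannot.
saved = {}

def build(op: str, operands: tuple, arg=None) -> str:
    key = (op, operands, arg)          # hashable only because operands is a tuple
    if key not in saved:
        saved[key] = f"{op}({', '.join(operands)}, arg={arg})"
    return saved[key]

assert build("ADD", ("a", "b")) is build("ADD", ("a", "b"))   # second call hits the cache
# build("ADD", ["a", "b"]) would raise TypeError: unhashable type: 'list'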
@@ -113,26 +113,26 @@ class Linearizer(OptimizedKernel):
 if key not in self.load_cache:
 if acc is not None:
 assert valid.min == 1
-self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
+self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, (), this_const)
 elif this_const is not None:
 self.load_cache[key] = self.const(this_const, localtype)
 if valid.min == 0 and valid.max == 1:
 valid_rendered = valid.render(self.render_ops, self)
-self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)], TernaryOps.WHERE, cachable=True)
+self.load_cache[key] = self.uop(UOps.ALU, localtype, (valid_rendered, self.load_cache[key], self.const(invalid_value, localtype)), TernaryOps.WHERE, cachable=True)
 else:
 buf_uop = self.buf_uops[i]
 assert buf_uop is not None, f"buffer {i} wasn't UOped"
 if isinstance(self.bufs[i].dtype, ImageDType):
 idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
-rendered_idx = self.uop(UOps.CAST, dtypes._int2, [idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)])
+rendered_idx = self.uop(UOps.CAST, dtypes._int2, (idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)))
 else:
 rendered_idx = idx.render(self.render_ops, self)
 if valid.min == 0:
 valid_rendered = valid.render(self.render_ops, self)
-self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)])
+self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx, valid_rendered, self.const(invalid_value, localtype)))
 else:
-self.load_cache[key] = self.uop(UOps.LOAD, localtype, [buf_uop, rendered_idx])
-ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
+self.load_cache[key] = self.uop(UOps.LOAD, localtype, (buf_uop, rendered_idx))
+ret.append(self.uop(UOps.GEP, dtypes.float32, (self.load_cache[key],), expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
 return ret

 def global_store(self, i:int, idxs:List[VariableOrNum], store:List[UOp]) -> None:
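For context on the hunk above: when the symbolic valid expression can evaluate to either 0 or 1, global_load selects between the loaded value and invalid_value with a TernaryOps.WHERE, so out-of-range lanes read a constant rather than out-of-bounds memory. A rough NumPy illustration of that masking idea (illustrative only, not tinygrad UOps; all names below are made up for the example):

import numpy as np

# WHERE(valid, loaded_value, invalid_value): lanes with an out-of-range index read a constant.
buf = np.arange(8, dtype=np.float32)
idx = np.array([2, 5, 9, -1])                          # some lanes index outside the buffer
valid = (idx >= 0) & (idx < buf.shape[0])
invalid_value = 0.0
loaded = np.where(valid, buf[np.clip(idx, 0, buf.shape[0] - 1)], invalid_value)
print(loaded)                                          # [2. 5. 0. 0.]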
@@ -156,17 +156,17 @@ class Linearizer(OptimizedKernel):
 idx, valid = self.sts[i].expr_idxs(k)
 assert idx.render() == ((idx//amt)*amt).render(), "float4 stores are always aligned"
 assert valid.min == 1, "stores are always valid"
-store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, out_tokens)
+store_offset_new[k] = self.uop(UOps.CAST, dtypes._float4 if amt == 4 else dtypes._float2, tuple(out_tokens))
 store_offset = store_offset_new

 for idx, var in store_offset.items():
 idx, valid = self.sts[i].expr_idxs(idx)
 if isinstance(self.bufs[i].dtype, ImageDType):
 idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
-rendered_idx = self.uop(UOps.CAST, dtypes._int2, [idx[0].render(self.render_ops, self), idx[1].render(self.render_ops, self)])
+rendered_idx = self.uop(UOps.CAST, dtypes._int2, tuple(x.render(self.render_ops, self) for x in idx))
 else:
 rendered_idx = idx.render(self.render_ops, self)
-self.uop(UOps.STORE, None, [buf_uop, rendered_idx, var])
+self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var))

 kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
 def linearize(self):
@@ -187,22 +187,22 @@ class Linearizer(OptimizedKernel):
 # add global buffers
 arg_bufs = {}
 for buf,name in self.arg_bufs.items():
-arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, [], (name, buf.dtype))
+arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (name, buf.dtype))
 for i,b in enumerate(self.bufs):
 if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
 # add variables from symbolic shapes
 for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
 assert var.expr is not None
-self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32))
+self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
 # define local buffers
 for lb in self.local_alias.values():
-self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], (lb.name, self.sts[self.bufs.index(lb)].size()))
+self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
 # add a local buffer for multistage reduce. # TODO: use local alias
 if self.group_for_reduce:
 # TODO: the strides of this can be controlled
 self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
 self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
-self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], ("temp", self.sts[-1].size())))
+self.buf_uops.append(self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ("temp", self.sts[-1].size())))

 # print
 if DEBUG >= 3: self.printbufs()
@@ -231,7 +231,7 @@ class Linearizer(OptimizedKernel):
 for x in xx[::-1]:
 if not isinstance(x, NumNode) and x.expr is not None:
 loop_uop = self.loop_uops[x.expr]
-if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, [loop_uop])
+if loop_uop.uop == UOps.LOOP: self.uop(UOps.END, None, (loop_uop,))

 if self.opts.has_local:
 self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
@@ -259,7 +259,7 @@ class Linearizer(OptimizedKernel):
 render_loop(reduce_idxs)

 # barrier for fast GEMM
-if self.use_tensor_cores: self.uop(UOps.BARRIER, None, [], ())
+if self.use_tensor_cores: self.uop(UOps.BARRIER, None, ())

 # compute local aliases
 # TODO: this is garbage code and should be at least moved elsewhere
@@ -305,24 +305,24 @@ class Linearizer(OptimizedKernel):
 i = 0
 for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
 for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
-self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
+self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
 i += 2
 else:
 k = len(locals_to_store[1][2]) // 2
 for i in range(0, len(acc), 2):
 for y0,y1,x0,x1 in zip(locals_to_store[1][2][:k], locals_to_store[1][2][k:], locals_to_store[0][2][k*i:], locals_to_store[0][2][k*i+k:]):
-self.uop(UOps.WMMA, None, [x0, x1, y0, y1, acc[i], acc[i+1]], "METAL")
+self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
 elif self.bufs[0].device == "HIP":
 i = 0
 for y in range(0, len(locals_to_store[1][2]), 0x10):
 for x in range(0, len(locals_to_store[0][2]), 0x10):
-self.uop(UOps.WMMA, None, acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10], "HIP")
+self.uop(UOps.WMMA, None, tuple(acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]), "HIP")
 i += 8
 else:
 if locals_to_store:
-self.uop(UOps.BARRIER, None, [], ())
+self.uop(UOps.BARRIER, None, ())
 for i, idxs, ll in locals_to_store: self.global_store(i, idxs, ll)
-self.uop(UOps.BARRIER, None, [], ())
+self.uop(UOps.BARRIER, None, ())

 # load earlybufs
 loaded_buffers.update({b:self.global_load(self.bufs.index(self.local_alias[i]) if i in self.local_alias else i, global_idxs+local_idxs+reduce_idxs+full_upcast_idxs) for i,b in enumerate(self.bufs[1:], start=1) if b in self.earlybufs})
@@ -338,7 +338,7 @@ class Linearizer(OptimizedKernel):
 if self.group_for_reduce:
 fake_global_idxs = [x*0 for x in global_idxs]
 self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
-self.uop(UOps.BARRIER, None, [], ())
+self.uop(UOps.BARRIER, None, ())
 end_loop(loop_local_idxs)

 # local indexs are over, 0 them out
@@ -398,14 +398,14 @@ class Linearizer(OptimizedKernel):

 return self

-def uop(self, uop:UOps, dtype:Optional[DType], vin:Union[Tuple[UOp, ...], List[UOp]], arg:Any=None, cachable=False) -> UOp:
-key = (uop, dtype, tuple(vin), arg)
+def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None, cachable=False) -> UOp:
+key = (uop, dtype, vin, arg)
 if uop == UOps.STORE and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self store is noop
 if uop == UOps.ALU:
 # rewrites. NOTE: the rewritten NEG op is still around...
-if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, [vin[0], vin[1].vin[0]], BinaryOps.SUB, cachable=cachable)
+if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable=cachable)
 # constant folding
-if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, [], -vin[0].arg, cachable=True)
+if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, (), -vin[0].arg, cachable=True)
 # zero folding
 for x in [0,1]:
 if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
@@ -414,7 +414,7 @@ class Linearizer(OptimizedKernel):
 if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
 if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
 if cachable and key in self.saved_exprs: return self.saved_exprs[key]
-self.uops.append(UOp(uop, dtype, tuple(vin), arg, len(self.uops)))
+self.uops.append(UOp(uop, dtype, vin, arg, len(self.uops)))
 if DEBUG >= 4: print(self.uops[-1])
 if cachable: self.saved_exprs[key] = self.uops[-1]
 return self.uops[-1]
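The two hunks above are where the tuples-only requirement pays off: key = (uop, dtype, vin, arg) is hashable as-is, UOp(uop, dtype, vin, ...) stores the operands without an extra copy, and saved_exprs deduplicates cachable expressions. A stripped-down sketch of that dedup flow, assuming a simplified UOpLite stand-in rather than tinygrad's real UOp:

from typing import Any, NamedTuple, Optional, Tuple

class UOpLite(NamedTuple):
    # simplified stand-in for tinygrad's UOp, just enough to show the caching
    uop: str
    dtype: Optional[str]
    vin: Tuple["UOpLite", ...]
    arg: Any
    num: int

class MiniLinearizer:
    def __init__(self):
        self.uops: list = []
        self.saved_exprs: dict = {}

    def uop(self, uop: str, dtype: Optional[str], vin: Tuple[UOpLite, ...], arg: Any = None, cachable: bool = False) -> UOpLite:
        key = (uop, dtype, vin, arg)          # vin is always a tuple, so no tuple(vin) copy is needed
        if cachable and key in self.saved_exprs: return self.saved_exprs[key]
        self.uops.append(UOpLite(uop, dtype, vin, arg, len(self.uops)))
        if cachable: self.saved_exprs[key] = self.uops[-1]
        return self.uops[-1]

lin = MiniLinearizer()
a = lin.uop("CONST", "float32", (), 1.0, cachable=True)
b = lin.uop("CONST", "float32", (), 1.0, cachable=True)   # cache hit, returns the same object
c = lin.uop("ALU", "float32", (a, b), "ADD", cachable=True)
assert a is b and len(lin.uops) == 2

Call sites that still assemble operands dynamically (out_tokens, the HIP WMMA slices) convert at the boundary with tuple(...), as the hunks above show.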
@@ -431,9 +431,9 @@ class Linearizer(OptimizedKernel):
 values = [self.ast_parse(v, acc, loaded_buffers) for v in x.src]
 ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
 if x.op in ops:
-ret = [(idx, self.uop(UOps.STORE, dtypes.float32, [val[-1], self.uop(UOps.ALU, dtypes.float32, list(val), ops[x.op])])) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values, acc))]
+ret = [(idx, self.uop(UOps.STORE, dtypes.float32, (val[-1], self.uop(UOps.ALU, dtypes.float32, val, ops[x.op])))) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values, acc))]
 else:
-ret = [(idx, self.uop(UOps.ALU, dtypes.float32, list(val), x.op, cachable=True)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
+ret = [(idx, self.uop(UOps.ALU, dtypes.float32, val, x.op, cachable=True)) for idx, val in zip([[i] for i in range(len(values[0]))], zip(*values))]
 ordered_ret: List[Optional[UOp]] = [None]*len(values[0])
 # scatter
 for i,j in ret: