mirror of https://github.com/tinygrad/tinygrad.git
use is to compare with enum (#3993)
* use is to compare with enum: currently it's mixed between `==` and `is`, moved all to `is`
* more
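For context, the rationale can be sketched outside the diff: Python enum members are singletons, so comparing them with `is` gives the same answer as `==` for plain Enums while being explicit about identity and immune to custom `__eq__` overloads. A minimal illustration with a hypothetical `Ops` enum (a stand-in, not part of the tinygrad codebase):

from enum import Enum, auto

class Ops(Enum):  # hypothetical stand-in for ReduceOps/BinaryOps/OptOps etc.
    SUM = auto()
    MUL = auto()
    MAX = auto()

op = Ops.SUM

# Enum members are singletons, so identity comparison is exact and explicit.
assert op is Ops.SUM
assert op is not Ops.MUL

# `==` gives the same result for a plain Enum, but `is` cannot be redefined by
# a custom __eq__ and never falls back to comparing against unrelated types.
assert (op == Ops.SUM) and (op is Ops.SUM)

# The related change from list literals to set literals for membership tests:
# `in {...}` hashes the members and reads as "one of these".
assert op in {Ops.SUM, Ops.MAX}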
@@ -332,19 +332,19 @@ class Kernel:
 # ******************** high level optimizers ********************
 
 def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
-if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op == ReduceOps.SUM and self.opts.device in tensor_cores:
+if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
 for tc in tensor_cores[self.opts.device]:
 has_cast = tc.dtype_in != tc.dtype_out
-if has_cast and not(self.reduceop.src[0].op == UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
+if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
 
 mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
-if mul_op.op != BinaryOps.MUL: continue
+if mul_op.op is not BinaryOps.MUL: continue
 
 def buf_index(src: LazyOp) -> Optional[int]:
 # TODO: apply tc even if the sources are not from LOAD
-if src.op == BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
+if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
 try:
-if opt_level >= 1 and src.op == UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
+if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg[0] == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
 except ValueError: return None
 return None
 if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: continue

@@ -405,7 +405,7 @@ class Kernel:
 def apply_opt(self, opt:Opt, append_opt:bool=True):
 check(not self.dont_use_locals or opt.op not in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP, OptOps.UPCASTMID}, "not using locals")
 
-if opt.op == OptOps.TC:
+if opt.op is OptOps.TC:
 check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine
 check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt")
 check((use_tensor_cores:=getenv("TC", 1)) == 2 or self.opts.has_tensor_cores, "must have tensor cores or TC=2")

@@ -414,34 +414,34 @@ class Kernel:
 return
 
 if opt.axis is not None:
-axis = opt.axis + (self.first_reduce if opt.op == OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
+axis = opt.axis + (self.first_reduce if opt.op is OptOps.UNROLL else (self.first_reduce+self.group_for_reduces if opt.op in [OptOps.GROUP, OptOps.GROUPTOP] else 0)) # noqa: E501
 else: axis = -1
 check(axis < len(self.full_shape), "invalid axis")
 
 if opt.amt is not None:
 amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
 check(isinstance(amt, int) and amt != 1, "shift/padto of amt 1 or Node is meaningless")
-if opt.op != OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
+if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
 else: amt = -1
 
-if self.reduceop and (opt.op in [OptOps.GROUP, OptOps.GROUPTOP] or (self.group_for_reduces and opt.op not in [OptOps.NOLOCALS, OptOps.PADTO])):
+if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
 acc_sz = dt.base.itemsize if isinstance((dt:=get_lazyop_info(self.reduceop).dtype), ImageDType) else dt.itemsize
 upcast_sz = prod(self.full_shape[self.shape_len-self.upcasted:])
 local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
 check(amt*acc_sz*upcast_sz*local_sz <= self.opts.shared_max, "exceeds maximum shared memory size")
 
-if opt.op == OptOps.LOCAL: # cyan
+if opt.op is OptOps.LOCAL: # cyan
 check(self.opts.has_local, "target does not support local")
 check(axis < self.global_dims, "local is for globals")
 self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
 self.local_dims += 1
-elif opt.op in [OptOps.GROUP, OptOps.GROUPTOP]: # green
+elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
 check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
 check(axis >= self.first_reduce + self.group_for_reduces and axis < self.shape_len-self.upcasted, "must be reduce axis to group")
 check(not self.tensor_core, "can't group with tensor cores")
-self.shift_to(axis, amt, top=(opt.op==OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
+self.shift_to(axis, amt, top=(opt.op is OptOps.GROUPTOP), insert_before=self.first_reduce + self.group_for_reduces)
 self.group_for_reduces += 1
-elif opt.op == OptOps.UNROLL: # purple
+elif opt.op is OptOps.UNROLL: # purple
 check(axis < self.shape_len-self.upcasted, "can't upcasted already upcasted")
 check(amt <= 32, "don't unroll more than 32")
 # TODO: fix upcast_count to put purples before yellows. broken because of METAL tensor cores

@@ -451,13 +451,13 @@ class Kernel:
 if self.full_shape[axis] == amt and axis < self.first_reduce+self.group_for_reduces: self.group_for_reduces -= 1 # fully unrolling a GROUP
 self.shift_to(axis, amt, insert_before=None)
 self.upcast()
-elif opt.op == OptOps.UPCAST: # yellow
+elif opt.op is OptOps.UPCAST: # yellow
 check(axis < self.first_reduce, "upcast is for non-reduce")
 check(not(self.tensor_core and axis >= self.first_reduce-len(self.tensor_core.threads)), "can't upcast TC locals")
 check(amt <= 8, "don't upcast more than 8")
 self.shift_to(axis, amt, insert_before=None)
 self.upcast()
-elif opt.op == OptOps.UPCASTMID: # white
+elif opt.op is OptOps.UPCASTMID: # white
 check(self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduces != 0 and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1, "invalid upcast mid reduce") # noqa: E501
 axes = self.sts[0].unit_stride_axes()
 check(len(axes) == 1, f"wrong number of stride 1 axis : {axes}")

@@ -465,11 +465,11 @@ class Kernel:
 check(amt == 4, "don't upcast mid anything but 4")
 self.shift_to(axis, amt, insert_before=self.first_reduce + self.group_for_reduces)
 self.group_for_reduces += 1
-elif opt.op == OptOps.NOLOCALS:
+elif opt.op is OptOps.NOLOCALS:
 check(self.opts.has_local and not self.dont_use_locals, "NOLOCALS is meaningless if target does not support local or already not using locals")
 check(self.local_dims == 0 and self.group_for_reduces == 0, "can't have no locals with locals")
 self.dont_use_locals = True
-elif opt.op == OptOps.PADTO:
+elif opt.op is OptOps.PADTO:
 check(not self.vars, "does not work with symbolic shape")
 check(axis < self.first_reduce, "cannot pad a reduce axis")
 padded = False

@@ -498,8 +498,8 @@ class Kernel:
 # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
 MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
 if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
-self.reduceop and self.reduceop.op == ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
-(mulop:=self.reduceop.src[0]).op == BinaryOps.MUL and mulop.src[0].op == BufferOps.LOAD and mulop.src[1].op == BufferOps.LOAD:
+self.reduceop and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
+(mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
 st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
 strides0, strides1 = st0.real_strides(), st1.real_strides()
 def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))

@@ -53,8 +53,8 @@ class Linearizer(Kernel):
 info = get_lazyop_info(reduceop)
 assert all(0 <= x < len(info.shape) for x in reduceop.arg), "arg axis out of range"
 dtype = info.dtype
-if reduceop.op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
-elif reduceop.op == ReduceOps.MAX:
+if reduceop.op is ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
+elif reduceop.op is ReduceOps.MAX:
 if dtypes.is_int(dtype): return 0 if dtypes.is_unsigned(dtype) else -2**(dtype.itemsize*8-1)
 return -math.inf if dtypes.is_float(dtype) else False

@@ -410,7 +410,7 @@ class Linearizer(Kernel):
 if cache is None: cache = {}
 if x in cache: return cache[x]
 if x.op in BufferOps: return loaded_buffers[x.arg]
-if x.op == UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
+if x.op is UnaryOps.CAST: return [self.uops.add(UOps.BITCAST if x.arg[1] else UOps.CAST, self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) \
 for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
 if x.op in ReduceOps and not do_reduce:
 assert offs is None, "not available if we aren't doing reduce"

@@ -428,6 +428,6 @@ class Linearizer(Kernel):
 if input_acc[off] != acc[off]:
 acc[off] = self.uops.add(UOps.PHI, input_acc[off].dtype, (input_acc[off], acc[off]) + tuple(loop_ctx))
 else:
-ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else val[-1].dtype, val, x.op) for val in zip(*values)]
+ret = [self.uops.add(UOps.ALU, dtypes.bool if x.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else val[-1].dtype, val, x.op) for val in zip(*values)]
 cache[x] = ret
 return ret

@@ -57,8 +57,8 @@ def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
 def uop_alu_resolve(u:UOp) -> sint:
 if u.uop is UOps.CONST: return u.arg
 elif u.uop is UOps.DEFINE_VAR: return u.arg
-elif u.uop is UOps.ALU and u.arg == BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
-elif u.uop is UOps.ALU and u.arg == BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
+elif u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return uop_alu_resolve(u.vin[0]) * uop_alu_resolve(u.vin[1])
+elif u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return uop_alu_resolve(u.vin[0]) + uop_alu_resolve(u.vin[1])
 else: raise RuntimeError(f"ALU resolve fail @ {u.uop}")
 
 def _match(uop:UOp, pattern:Dict[str, Any], store:Dict[str, UOp]) -> bool:

@@ -198,7 +198,7 @@ class UOpGraph:
 while ssize != len(deps):
 ssize = len(deps)
 for u in self.uops:
-if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
+if len(deps.intersection([x for x in u.vin if x.uop is not UOps.PHI])):
 deps.add(u)
 return deps

@@ -237,16 +237,16 @@ class UOpGraph:
 
 def simplify_phi_loops(self, get_recursive_parents):
 def alu_opposite(arg, x, y):
-if arg == BinaryOps.ADD: return x - y
-elif arg == BinaryOps.MUL: return Node.__floordiv__(x, y, False)
+if arg is BinaryOps.ADD: return x - y
+elif arg is BinaryOps.MUL: return Node.__floordiv__(x, y, False)
 else: raise RuntimeError("unhandled alu")
 def to_symbolic(u: UOp):
-if u.uop == UOps.CONST: return NumNode(int(u.arg))
+if u.uop is UOps.CONST: return NumNode(int(u.arg))
 elif u.uop in {UOps.LOOP, UOps.SPECIAL}:
 if u not in seen_vars: seen_vars[u] = u.arg[1] if u.uop is UOps.SPECIAL else "loop{}".format(len(seen_vars))
 return Variable(seen_vars[u], u.vin[0].arg, u.vin[1].arg-1) if u.uop is UOps.LOOP else Variable(seen_vars[u], 0, u.arg[2]-1)
-elif u.uop == UOps.ALU and u.arg == BinaryOps.ADD: return to_symbolic(u.vin[0]) + to_symbolic(u.vin[1])
-elif u.uop == UOps.ALU and u.arg == BinaryOps.MUL: return to_symbolic(u.vin[0]) * to_symbolic(u.vin[1])
+elif u.uop is UOps.ALU and u.arg is BinaryOps.ADD: return to_symbolic(u.vin[0]) + to_symbolic(u.vin[1])
+elif u.uop is UOps.ALU and u.arg is BinaryOps.MUL: return to_symbolic(u.vin[0]) * to_symbolic(u.vin[1])
 else: raise RuntimeError("unhandled op: {}".format(u))
 def loop_factor(with_loop: UOp, factored: Node, loop_op, round_up=False):
 if with_loop == loop_op: return factored

@@ -100,7 +100,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
 if buf.forced_realize: realizes.add(buf)
 allbufs[buf] = None
 if buf.op in LoadOps: realizes.add(buf.base)
-if buf.op == LoadOps.COPY:
+if buf.op is LoadOps.COPY:
 assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
 realizes.add(buf.srcs[0].base)
 for x in buf.srcs:

@@ -49,7 +49,7 @@ top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", Bin
 TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}
 def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
 init_graph()
-if lb.base.realized is None and lb.base.op == LoadOps.CONST: return
+if lb.base.realized is None and lb.base.op is LoadOps.CONST: return
 if lb.base != lb:
 offset = lb.st.expr_idxs([NumNode(0)] * len(lb.st.shape))[0]
 label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")

@@ -60,7 +60,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
 label_append = []
 for idx,x in enumerate(lb.srcs):
 if nm(x) not in G.nodes: log_lazybuffer(x)
-if x.base.realized is None and x.base.op == LoadOps.CONST:
+if x.base.realized is None and x.base.op is LoadOps.CONST:
 label_append.append(f"\nCONST{idx} {x.base.arg}")
 else:
 G.add_edge(nm(x), nm(lb), color='#a0a0a0')

@@ -92,8 +92,8 @@ class LazyBuffer:
 new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
 return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, UnaryOps.CAST, (dtype, bitcast), (self,))
 
-def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
-def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST
+def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST
+def is_unrealized_contiguous_const(self): return self.base == self and self.base.realized is None and self.op is LoadOps.CONST
 
 def _copy(self, device:str) -> LazyBuffer:
 if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart):

@@ -89,7 +89,7 @@ InterpretedFlopCounter: Dict[Op, Callable] = {
 BufferOps.CONST: lambda arg: FlopCounter(arg.st.shape, arg.dtype, 0, {}),
 BufferOps.STORE: lambda self,arg: FlopCounter(arg.st.shape, arg.dtype, self.consume_flops(), {**self.mem, arg.idx: arg.dtype.itemsize*arg.st.real_size()}), # noqa: E501
 UnaryOps.CAST: lambda self,arg: FlopCounter(self.shape, arg[0], self.consume_flops(), self.mem), # cast uses no flops
-**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op != UnaryOps.CAST}, # noqa: E501
+**{op:lambda self: FlopCounter(self.shape, self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in UnaryOps if op is not UnaryOps.CAST}, # noqa: E501
 **{op:lambda self,y,op=op: FlopCounter(self.shape, dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else self.dtype, self.consume_flops() + y.consume_flops() + prod(self.shape), {**self.mem, **y.mem}) for op in BinaryOps}, # noqa: E501
 **{op:lambda self,axis: FlopCounter(tuple(1 if i in axis else s for i,s in enumerate(self.shape)), self.dtype, self.consume_flops() + prod(self.shape), self.mem) for op in ReduceOps}, # noqa: E501
 TernaryOps.WHERE: lambda self,y,z: FlopCounter(self.shape, y.dtype, self.consume_flops() + y.consume_flops() + z.consume_flops() + prod(self.shape), {**self.mem, **y.mem, **z.mem})} # noqa: E501

@@ -66,10 +66,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
 
 def ptr_ar(root):
 assert root.arg in {'.shared', '.global', None}
-if root.arg is None: root.arg = '.shared' if root.vin[0].uop == UOps.DEFINE_LOCAL else '.global' # move this to the argL
+if root.arg is None: root.arg = '.shared' if root.vin[0].uop is UOps.DEFINE_LOCAL else '.global' # move this to the argL
 val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=root.vin[0].dtype.itemsize, insert_before=uops.uops.index(root))
 ptr = uops.add(UOps.ALU, dtypes.int, (root.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(root))
-if ptr.uop == UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
+if ptr.uop is UOps.CONST: root.vin = (root.vin[0], ptr) + root.vin[2:]
 else:
 zero = uops.add(UOps.CONST, dtypes.int, tuple(), arg=0, cachable=False, insert_before=uops.uops.index(root))
 bptr = uops.add(UOps.CAST, dtypes.uint64, (ptr,), insert_before=uops.uops.index(root))

@@ -133,17 +133,17 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
 
 for u in uops:
 uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
-if uop == UOps.IF:
+if uop is UOps.IF:
 assert vin[0].dtype is not None
 kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
-elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
-elif uop == UOps.ENDLOOP:
+elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
+elif uop is UOps.ENDLOOP:
 kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
 lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
 kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
-elif uop == UOps.ENDIF:
+elif uop is UOps.ENDIF:
 kk(f"{r_label[vin[0]]}:")
-elif uop == UOps.STORE:
+elif uop is UOps.STORE:
 assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
 if vin[2].dtype.count > 1:
 kk((f"@{r[vin[3]]} " if len(vin)>3 else "") + \

@@ -152,8 +152,8 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
 kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
 else:
 assert dtype is not None, f"None dtype for uop {uop}"
-if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
-elif uop == UOps.ALU:
+if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
+elif uop is UOps.ALU:
 assert vin[0].dtype is not None
 operands = [r[x] for x in vin]
 lab = ssa(u, "alu")

@@ -163,28 +163,28 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
 for i, op in enumerate(operands):
 operands[i] = ssa(None, "alu_cast", lang.types[dtype])
 kk(*lang.render_cast(operands[i], op, dtype, dtypes.half)) # type: ignore
-if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
+if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
 # pass in the other dtype here
 kk(lang.asm_for_op[args](lab, *operands, vin[0].dtype, lang.types[vin[0].dtype]))
 else:
 kk(lang.asm_for_op[args](lab, *operands, dtype, lang.types[dtype]))
 if needs_upcast:
 kk(*lang.render_cast(out_lab, lab, dtypes.half, dtype))
-elif uop == UOps.DEFINE_ACC:
+elif uop is UOps.DEFINE_ACC:
 if dtype.count > 1:
 r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
 for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
 else: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
-elif uop == UOps.SPECIAL:
+elif uop is UOps.SPECIAL:
 assert args[1][0] != "i", "idx not supported"
 kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
 r[u] = "%" + args[1]
 kernel = [f".reg .u32 %{args[1]};"] + kernel
-elif uop == UOps.CONST:
+elif uop is UOps.CONST:
 if dtype.count > 1: r[u] = [const(args, dtype.scalar(), mov=True) for _ in range(dtype.count)]
 else: r[u] = const(args, dtype, mov=True)
-elif uop == UOps.GEP: r[u] = r[vin[0]][u.arg]
-elif uop == UOps.LOAD:
+elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
+elif uop is UOps.LOAD:
 assert vin[1].dtype is not None
 if dtype.count > 1:
 r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]

@@ -195,14 +195,14 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
 else:
 kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
 alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
-elif uop == UOps.PHI:
+elif uop is UOps.PHI:
 kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
 r[u] = r[vin[0]]
 elif uop in {UOps.CAST, UOps.BITCAST}:
 assert vin[0].dtype is not None
 if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
 else: cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
-elif uop == UOps.DEFINE_LOCAL:
+elif uop is UOps.DEFINE_LOCAL:
 # TODO: we should sum these, and fetch 0xC000 from somewhere
 assert args[1]*dtype.itemsize <= 0xC000, "too large local"
 kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))

@@ -131,7 +131,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
 val = lang.code_for_op[args](*operands, dtype)
 assert child_count[u] != 0, f"childless ALU op found {u}"
 # TODO: fix index rendering issue. fix clang nested max macro issue
-if child_count[u] <= 1 and args != BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
+if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
 else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'alu')} = {val};")
 elif uop is UOps.SPECIAL:
 kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")

@@ -215,7 +215,7 @@ class MetalLanguage(CStyleLanguage):
 return f"as_type<{self.render_dtype(var_dtype)}>({x[0]})" if bitcast else super().render_cast(x, var_dtype)
 
 def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop == UOps.WMMA])
+prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.uop is UOps.WMMA])
 for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
 simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
 b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);

@@ -257,7 +257,7 @@ class CUDALanguage(CStyleLanguage):
 prefix += ["#include <cuda_bf16.h>"] + [_make_cuda_dtype("nv_bfloat16", "bfloat16", x) for x in [4, 8]]
 
 # TODO: this has to be way better to generate for arbitrary M,N,K: use arg[1] for MNK, use arg[4] for vec sizes, encode register packing
-for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]):
+for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]):
 fn, ti, to, ci, co = arg[0], dt_map[arg[2]][0], dt_map[arg[3]][0], dt_map[arg[2]][1], dt_map[arg[3]][1]
 prefix.append(f"""__device__ {to}4 __{fn}({ti}8 a, {ti}4 b, {to}4 c) {{ int *a_pk = (int *) (&a), *b_pk = (int *) (&b);
 asm( "mma.sync.aligned.m16n8k16.row.col.{co}.{ci}.{ci}.{co} {{ %0, %1, %2, %3 }}, {{ %4, %5, %6, %7 }}, {{ %8, %9 }}, {{ %0, %1, %2, %3 }};"

@@ -338,7 +338,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
 
 prefix += [_make_hip_dtype(*x) for x in vec_dts]
 
-for arg in set([uop.arg for uop in uops if uop.uop == UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+for arg in set([uop.arg for uop in uops if uop.uop is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
 if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
 else: prefix.append(f"static __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
 half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }

@@ -347,7 +347,7 @@ static __attribute__((device)) bool operator==(hip_bfloat16 a, hip_bfloat16 b) {
 return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 def get_kernel_modifier(self, uops:UOpGraph) -> str:
-requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop == UOps.SPECIAL and u.arg[1][0] == "l")
+requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop is UOps.SPECIAL and u.arg[1][0] == "l")
 # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
 # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
 return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

@@ -109,7 +109,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
 bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[vin[0].vin[1]]), loop_entry_bb, bb[-1].block)
 else:
 assert dtype is not None, f"None dtype for uop {uop}"
-if uop == UOps.LOOP:
+if uop is UOps.LOOP:
 bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
 bb[-2].branch(bb[-1].block)

@@ -138,7 +138,7 @@ def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
 lvars[u] = lvars[vin[1]]
 # PHI UOps can link to other PHI Uops, backtrace this to DEFINE_ACC
 backward = vin[0]
-while backward.uop == UOps.PHI: backward = backward.vin[0]
+while backward.uop is UOps.PHI: backward = backward.vin[0]
 lvars[backward] = lvars[u]
 elif uop is UOps.ALU:
 lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin], dtype if args not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else vin[0].dtype)

@@ -48,17 +48,17 @@ class DiskRunner(JITRunner):
 skip_allocation = True
 def __init__(self, ast:LazyOp):
 # two ASTs are allowed here.
-assert ast.op == BufferOps.STORE, "output of AST must be store"
+assert ast.op is BufferOps.STORE, "output of AST must be store"
 assert ast.arg.st.contiguous, "shapetracker must be contiguous"
 # TODO: there shouldn't actually be casts here, bitcasts should fold into the load
-if ast.src[0].op == UnaryOps.CAST:
+if ast.src[0].op is UnaryOps.CAST:
 top_src = ast.src[0].src[0]
 assert ast.src[0].arg[1], "disk only supports bitcasts, not normal casts"
 self.new_dtype = ast.src[0].arg[0]
 else:
 top_src = ast.src[0]
 self.new_dtype = top_src.arg.dtype
-assert top_src.op == BufferOps.LOAD, "top of AST must be load"
+assert top_src.op is BufferOps.LOAD, "top of AST must be load"
 assert len(top_src.arg.st.views) == 1, "shapetracker must have 1 view"
 view = top_src.arg.st.views[0]
 assert view.mask is None, "view cannot have a mask"