diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index 42f33163a6..e2690baab8 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
 import struct
 from tinygrad.dtype import dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # allocate some buffers + load in values
@@ -49,14 +49,14 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
 # NOTE: a._buf is the same as the return from MallocAllocator.alloc
 
 # describe the computation
-buf_1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
-buf_2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
-ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
-ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
+buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
+buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
+ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
+ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
 alu = ld_1 + ld_2
-output_buf = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
-st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
-s = UOp(UOps.SINK, dtypes.void, (st_0,))
+output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
+st_0 = UOp(Ops.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
+s = UOp(Ops.SINK, dtypes.void, (st_0,))
 
 # convert the computation to a "linearized" format (print the format)
 from tinygrad.engine.realize import get_kernel, CompiledRunner
diff --git a/docs/developer/uop.md b/docs/developer/uop.md
index d529aff972..063c475bc5 100644
--- a/docs/developer/uop.md
+++ b/docs/developer/uop.md
@@ -4,7 +4,7 @@
         members_order: source
         show_labels: false
 
-::: tinygrad.ops.UOps
+::: tinygrad.ops.Ops
     options:
         members: true
         members_order: source
diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index 5f951c5e01..8d2e2833cd 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -4,7 +4,7 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.ops import UOps, sym_infer
+from tinygrad.ops import Ops, sym_infer
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
@@ -66,7 +66,7 @@ if __name__ == "__main__":
 
   print(f"optimizing for {Device.DEFAULT}")
   sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
-  sched = [x for x in sched if x.ast.op is UOps.SINK]
+  sched = [x for x in sched if x.ast.op is Ops.SINK]
 
   # focus on one kernel
   if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
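The docs/abstractions2.py hunk above is the clearest end-to-end view of this rename: kernel ASTs are now built from Ops members rather than UOps. A minimal sketch of the same 2+3 kernel under the post-rename API the doc file itself imports; the to_program/exec plumbing here is an assumption inferred from those imports, not something this diff shows:

import struct
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import get_kernel, CompiledRunner

DEVICE = "CLANG"
# two int32 input buffers holding 2 and 3, plus one output buffer
out = Buffer(DEVICE, 1, dtypes.int32).allocate()
a = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
# the AST, built from Ops members exactly as in the hunk above
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), ld_1 + ld_2))
s = UOp(Ops.SINK, dtypes.void, (st_0,))
# lower to a program and run it (assumed entry points, per the imports above)
fxn = CompiledRunner(get_kernel(Device[DEVICE].renderer, s).to_program())
fxn.exec([out, a, b])
assert out.as_buffer().cast("I")[0] == 5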
diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py
index 0f9cc3accb..46e36dddfc 100755
--- a/examples/llm.c/export.py
+++ b/examples/llm.c/export.py
@@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import get_kernel, run_schedule
 from tinygrad.engine.memory import memory_planner
-from tinygrad.ops import MetaOps, UOps
+from tinygrad.ops import MetaOps, Ops
 
 TIMING = getenv("TIMING")
 
@@ -41,7 +41,7 @@ if __name__ == "__main__":
     print(f"calls {i}:", len(sched))
     #run_schedule(sched[:])
     sched = memory_planner(sched)
-    ast_dedup = dedup([si.ast for si in sched if si.ast.op is UOps.SINK])
+    ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK])
     srcs = {}
     for ast in ast_dedup:
       k = get_kernel(Device["CLANG"].renderer, ast)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
     for i,si in enumerate(sched):
      bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
      all_bufs += bufs
-     if si.ast.op is not UOps.SINK:
+     if si.ast.op is not Ops.SINK:
        print(f"// {si.ast.op}", bufs)
      else:
        print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})")
diff --git a/examples/openpilot/compile2.py b/examples/openpilot/compile2.py
index cb247aa6aa..d862bf203c 100644
--- a/examples/openpilot/compile2.py
+++ b/examples/openpilot/compile2.py
@@ -19,7 +19,7 @@ from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
 from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.tensor import _to_np_dtype
 
 Device.DEFAULT = "GPU"
@@ -50,7 +50,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
   print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
 
   # confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
-  assert all(si.ast.op is UOps.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
+  assert all(si.ast.op is Ops.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
   return schedule, schedule_independent, inputs
 
 def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
@@ -106,7 +106,7 @@ if __name__ == "__main__":
     #exit(0)
 
   schedule, schedule_independent, inputs = get_schedule(onnx_data)
-  schedule, schedule_input = partition(schedule, lambda x: x.ast.op is UOps.SINK)
+  schedule, schedule_input = partition(schedule, lambda x: x.ast.op is Ops.SINK)
   print(f"{len(schedule_input)} inputs")
   run_schedule(schedule_independent)
 
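handcode_opt.py, export.py, and compile2.py above all repeat one idiom after the rename: a ScheduleItem describes a real compute kernel exactly when its AST sinks, i.e. si.ast.op is Ops.SINK; buffer copies and other metaops fail that check. A small self-contained sketch of the same filter, using only the public Tensor.schedule() API and the partition helper compile2.py already imports:

from tinygrad import Tensor
from tinygrad.helpers import partition
from tinygrad.ops import Ops

# schedule a tiny computation without running it
sched = (Tensor.ones(16).contiguous() + 1).schedule()
# kernels sink their STOREs; copies and other metaops do not
kernels, others = partition(sched, lambda si: si.ast.op is Ops.SINK)
print(f"{len(kernels)} kernel(s), {len(others)} non-kernel schedule item(s)")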
diff --git a/extra/assembly/assembly.py b/extra/assembly/assembly.py
index 1133675d23..666282c05a 100644
--- a/extra/assembly/assembly.py
+++ b/extra/assembly/assembly.py
@@ -1,5 +1,5 @@
 from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
-from tinygrad.codegen.kernel import UOps, MemOp, UOp
+from tinygrad.codegen.kernel import Ops, MemOp, UOp
 from tinygrad.ops import BinaryOps, UnaryOps
 from tinygrad.dtype import DType, dtypes
 from tinygrad.helpers import DEBUG
@@ -23,7 +23,7 @@ class Register(NamedTuple):
     return []
 
 class AssemblyInstruction(NamedTuple):
-  op: UOps
+  op: Ops
   out: Optional[Register]
   vin: List[Union[Register, int, float]]
   arg: Any = None
@@ -49,21 +49,21 @@ class AssemblyLanguage:
 
   def render_numnode(self, b) -> Register:
     key = ("num", b)
-    if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
+    if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
     return self.tor[key]
 
   def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
     key = (op, a, b)
     if key not in self.tor:
       #if not isinstance(b, Register): b = render_numnode(b)
-      self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
+      self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
     return self.tor[key]
 
   def render_cast(self, a:Register, new_dtype:DType) -> Register:
     if a.dtype == new_dtype: return a
     key = (a, new_dtype)
     if key not in self.tor:
-      self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
+      self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
     return self.tor[key]
 
   render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
@@ -87,7 +87,7 @@ class AssemblyLanguage:
       if self.supports_load3:
         if reg.scalar:
          new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
-         self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
+         self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
          reg = new_reg
        return self.tor[args.name], reg, off
      reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
@@ -98,91 +98,91 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
   lang.ins.clear()
   lang.tor.clear()
   lang.cnts.clear()
-  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
+  buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
   global_size, local_size = [], []
   skipload_branch = 0
-  lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
+  lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
   for u in uops:
     uop,dtype,vin,args,_ = u
-    if uop == UOps.DEFINE_LOCAL:
-      lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
-      lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
-    elif uop == UOps.LOOP:
+    if uop == Ops.DEFINE_LOCAL:
+      lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
+      lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
+    elif uop == Ops.LOOP:
       if args[1] == "global":
         for i,var in enumerate(args[0]):
           global_size.append(var.max+1)
-          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
+          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
       elif args[1] == "local":
         for i,var in enumerate(args[0]):
           local_size.append(var.max+1)
-          lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
+          lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
       else:
         for var in args[0]:
           if not isinstance(var, NumNode): # TODO: why is this coming through?
-            lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
-            lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
-    elif uop == UOps.ENDLOOP:
+            lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
+            lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
+    elif uop == Ops.ENDLOOP:
       if args[1] not in ["global", "local", "global+local"]:
         for var in reversed(args[0]):
           if not isinstance(var, NumNode): # TODO: why is this coming through?
-            lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
+            lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
             pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
-            lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
+            lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
       elif args[1] == "global+local":
         for i, var in enumerate(reversed(args[0])):
-          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
+          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
       elif args[1] == 'local':
         for i, var in enumerate(reversed(args[0])):
-          lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
-    elif uop == UOps.CAST:
+          lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
+    elif uop == Ops.CAST:
       # TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
       out = lang.newreg(u, dtype)
       for i,sr in enumerate(out.subregs()):
-        lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
-    elif uop == UOps.ALU:
+        lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
+    elif uop == Ops.ALU:
       out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u] # this is the only thing that can violate SSA
       if args in [BinaryOps.CMPLT]:
         pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
-        lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
-        lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
+        lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
       elif args == BinaryOps.DIV and lang.no_div:
         tmp = lang.newreg((u, "rcp"))
-        lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
-        lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
+        lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
       elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
        tmp = lang.newreg((u, "2pi"))
-       lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
-       lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
+       lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
+       lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
      else:
-       lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
-    elif uop == UOps.DEFINE_ACC:
+       lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
+    elif uop == Ops.DEFINE_ACC:
      reg = lang.newreg(u, dtype=dtype)
-     lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
-    elif uop == UOps.SPECIAL:
+     lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
+    elif uop == Ops.SPECIAL:
      lang.tor[u] = lang.tor[args]
-    elif uop == UOps.CONST:
-     lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
-    elif uop == UOps.LOAD:
+    elif uop == Ops.CONST:
+     lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
+    elif uop == Ops.LOAD:
      idx, treg, off = lang.addr_w_offset(args)
      reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
      if args.valid.min == 0:
-       lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
+       lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
       if args.valid.max == 1:
         pred = args.valid.render(lang.render_ops, lang)
-        lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
+        lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
      if args.valid.max == 1:
        # NOTE: you can't compute the index in here, because it assumes it's all available later
-       lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
+       lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
      if args.valid.min == 0 and args.valid.max == 1:
-       lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
+       lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
        skipload_branch += 1
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
      if args is None:
-       lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
+       lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
      else:
        idx, treg, off = lang.addr_w_offset(args)
-       lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
+       lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
 
   if DEBUG >= 4:
     for tins in lang.ins: print(tins)
diff --git a/extra/assembly/assembly_arm64.py b/extra/assembly/assembly_arm64.py
index d165190e43..bca479ab41 100644
--- a/extra/assembly/assembly_arm64.py
+++ b/extra/assembly/assembly_arm64.py
@@ -3,7 +3,7 @@ from platform import system
 from typing import Tuple, Dict, List, Optional
 from tinygrad import dtypes
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.codegen.kernel import UOps, UOp
+from tinygrad.codegen.kernel import Ops, UOp
 from tinygrad.helpers import CI
 from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
 
@@ -19,7 +19,7 @@ class ARM64Language(AssemblyLanguage): pass
 
 def specialize_to_arm64(fn_nm, asm):
   var_size = 16
-  prev_uop:Optional[UOps] = None
+  prev_uop:Optional[Ops] = None
   ins = []
   x_regs = ['x' + str(i) for i in reversed(range(12))]
   s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
@@ -81,7 +81,7 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"mov x15, {mem_vars[v.nm]}")
         ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
 
-    if uop == UOps.SPECIAL:
+    if uop == Ops.SPECIAL:
       if arg.startswith('data'):
         # data 8 to n into the stack
         if int(arg[4:]) >= 8:
@@ -90,7 +90,7 @@ def specialize_to_arm64(fn_nm, asm):
       else:
         ins.append(f"mov {rtor[out.nm]}, #0")
         ins.append(f"loop_{arg}:")
-    elif uop == UOps.CAST:
+    elif uop == Ops.CAST:
       if arg == BinaryOps.CMPLT:
         if rtor[out.nm][0] == 's':
           mov_imm(0.0, 's0')
@@ -102,7 +102,7 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
       else:
         ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
-    elif uop == UOps.ALU:
+    elif uop == Ops.ALU:
       if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
       if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
         ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
@@ -136,7 +136,7 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
       else:
         ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
-    elif uop == UOps.LOAD:
+    elif uop == Ops.LOAD:
       if arg.__class__ in (int, float):
         mov_imm(arg, rtor[out.nm])
       else:
@@ -146,20 +146,20 @@ def specialize_to_arm64(fn_nm, asm):
         ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
         ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
         if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
       #NOTE: if need casting load var in s/h0 or x/w12 temp regs
       reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
       if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
       ins.append(f"mov x15, #{arg[0]}")
       ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
-    elif uop == UOps.COND_BRANCH:
+    elif uop == Ops.COND_BRANCH:
       #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
-      if prev_uop == UOps.LOAD:
+      if prev_uop == Ops.LOAD:
         ins.append(f"cmp {rtor[vin[0].nm]}, #0")
       ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
-    elif uop == UOps.LABEL:
+    elif uop == Ops.LABEL:
       ins.append(f"{arg[1:]}:")
-    elif uop == UOps.ENDLOOP:
+    elif uop == Ops.ENDLOOP:
       mov_imm(arg[0], "x15")
       ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
       ins.append(f"cmp {rtor[vin[0].nm]}, x15")
diff --git a/extra/assembly/assembly_ptx.py b/extra/assembly/assembly_ptx.py
index 1c71fa691b..7ac1ff5616 100644
--- a/extra/assembly/assembly_ptx.py
+++ b/extra/assembly/assembly_ptx.py
@@ -1,7 +1,7 @@
 from typing import List
 import struct
 from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
-from tinygrad.codegen.kernel import UOps, UOp
+from tinygrad.codegen.kernel import Ops, UOp
 from tinygrad import dtypes
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
 from tinygrad.runtime.ops_cuda import arch
@@ -37,11 +37,11 @@ def specialize_to_ptx(lang, function_name):
          UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
          TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
   for uop, out, vin, arg in lang.ins:
-    if uop == UOps.ENDLOOP:
+    if uop == Ops.ENDLOOP:
       ins.append("bar.sync 0;")
-    elif uop == UOps.DEFINE_LOCAL:
+    elif uop == Ops.DEFINE_LOCAL:
       ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
-    elif uop == UOps.SPECIAL:
+    elif uop == Ops.SPECIAL:
       if arg.startswith('data'):
         param_cnt += 1
         ins.append(f"ld.param.u64 {out}, [{arg}];")
@@ -51,7 +51,7 @@ def specialize_to_ptx(lang, function_name):
         ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
       elif arg.startswith('lid'):
         ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
-    elif uop == UOps.ALU:
+    elif uop == Ops.ALU:
       if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
         ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
       else:
@@ -64,7 +64,7 @@ def specialize_to_ptx(lang, function_name):
           ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
           vin = vin[1:] + [reg]
         ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
-    elif uop == UOps.LOAD:
+    elif uop == Ops.LOAD:
       if arg.__class__ in (int, float):
         ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
       elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
@@ -74,7 +74,7 @@ def specialize_to_ptx(lang, function_name):
         render_cast(ins, reg, out)
       else:
         ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
       if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
         if arg[2] == dtypes.bool != vin[1].dtype:
           prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
@@ -85,11 +85,11 @@ def specialize_to_ptx(lang, function_name):
         ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
       else:
         ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
-    elif uop == UOps.CAST:
+    elif uop == Ops.CAST:
       render_cast(ins, vin[0], out)
-    elif uop == UOps.LABEL:
+    elif uop == Ops.LABEL:
       ins.append(f"{arg}:")
-    elif uop == UOps.COND_BRANCH:
+    elif uop == Ops.COND_BRANCH:
       ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
 
   ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",
diff --git a/extra/assembly/assembly_rdna.py b/extra/assembly/assembly_rdna.py
index ad8d36b0cf..cb41ba8f22 100644
--- a/extra/assembly/assembly_rdna.py
+++ b/extra/assembly/assembly_rdna.py
@@ -2,7 +2,7 @@ import yaml
 from typing import Tuple, Set, Dict
 from tinygrad import dtypes
 from tinygrad.codegen.assembly import AssemblyCodegen, Register
-from tinygrad.codegen.kernel import UOps
+from tinygrad.codegen.kernel import Ops
 from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
 from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
 
@@ -61,7 +61,7 @@ class RDNACodegen(AssemblyCodegen):
     def reg_out(x): return rtor[x]
 
     for uop, out, vin, arg in asm:
-      if uop == UOps.DEFINE_REGISTER:
+      if uop == Ops.DEFINE_REGISTER:
         if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
           for i in range(arg[2]):
             # TODO: Re-use gaps created by this to avoid wasting registers
@@ -86,7 +86,7 @@ class RDNACodegen(AssemblyCodegen):
             rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
         else:
           raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
-      elif uop == UOps.SPECIAL:
+      elif uop == Ops.SPECIAL:
         if arg.startswith('buf'):
           i = int(arg[3:])
          ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
@@ -106,7 +106,7 @@ class RDNACodegen(AssemblyCodegen):
           pend_regs.clear()
           ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
           ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
-      elif uop == UOps.CONST:
+      elif uop == Ops.CONST:
         if arg == float('inf'): arg = "0x7f800000"
         elif arg == float('-inf'): arg = "0xff800000"
         if out.dtype == dtypes.float.vec(4):
@@ -114,7 +114,7 @@ class RDNACodegen(AssemblyCodegen):
            ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
        else:
          ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
-      elif uop == UOps.ALU:
+      elif uop == Ops.ALU:
        if arg in [BinaryOps.CMPLT]:
          ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
        else:
@@ -127,7 +127,7 @@ class RDNACodegen(AssemblyCodegen):
            ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
        else:
          ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
-      elif uop == UOps.LOAD:
+      elif uop == Ops.LOAD:
        if out.scalar:
          # swap arg order
          ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
@@ -135,13 +135,13 @@ class RDNACodegen(AssemblyCodegen):
          ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
          pend_regs.add(out)
          for r in out.subregs(): pend_regs.add(r)
-      elif uop == UOps.STORE:
+      elif uop == Ops.STORE:
        ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
-      elif uop == UOps.LABEL:
+      elif uop == Ops.LABEL:
        ins.append(f"{arg}:")
-      elif uop == UOps.COND_BRANCH:
+      elif uop == Ops.COND_BRANCH:
        ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
-      elif uop == UOps.CAST:
+      elif uop == Ops.CAST:
        if vin[0].dtype == dtypes.bool:
          if out.dtype == dtypes.float32:
            ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
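All four assembly backends above (and the triton backend below) share one skeleton: a single pass over the linearized instructions that dispatches on the renamed Ops member. A stripped-down, runnable sketch of that dispatch shape with toy rendering logic; the local Ops enum here is a stand-in for the one these archived backends import from tinygrad.codegen.kernel, and only the member names and the (op, out, vin, arg) tuple layout come from the hunks above:

from enum import Enum, auto

class Ops(Enum):  # stand-in enum, hypothetical subset for illustration
  LABEL = auto(); COND_BRANCH = auto(); ALU = auto()

def render_asm(asm) -> list:
  # asm: AssemblyInstruction-style (op, out, vin, arg) tuples, as in the backends above
  ins = []
  for uop, out, vin, arg in asm:
    if uop == Ops.LABEL: ins.append(f"{arg}:")
    elif uop == Ops.COND_BRANCH: ins.append(f"branch {vin[0]} -> {arg[0]}")  # toy syntax
    elif uop == Ops.ALU: ins.append(f"{out} = {arg}({', '.join(map(str, vin))})")  # toy syntax
    else: raise NotImplementedError(f"unimplemented: {uop}")
  return ins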
diff --git a/extra/backends/triton.py b/extra/backends/triton.py
index a10f172e8f..9da79b108c 100644
--- a/extra/backends/triton.py
+++ b/extra/backends/triton.py
@@ -2,7 +2,7 @@ from typing import Dict, List, Final, Callable, DefaultDict
 from collections import defaultdict
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
 from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
-from tinygrad.codegen.kernel import UOp, UOps
+from tinygrad.codegen.kernel import UOp, Ops
 from triton.compiler import compile as triton_compile
 import linecache
 import math
@@ -75,32 +75,32 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
   def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
   for u in uops:
     uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
-    if uop == UOps.LOOP:
+    if uop == Ops.LOOP:
       kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
       depth += 1
-    elif uop == UOps.END: depth -= 1
-    elif uop == UOps.ALU:
+    elif uop == Ops.END: depth -= 1
+    elif uop == Ops.ALU:
       assert dtype is not None
       val = code_for_op[args](*[r[x] for x in vin])
       if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
       else: kk(f"{ssa(u, 'alu')} = ({val})")
-    elif uop == UOps.LOAD:
+    elif uop == Ops.LOAD:
       assert dtype is not None
       if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
       else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
-    elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
-    elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
-    elif uop == UOps.ASSIGN:
+    elif uop == Ops.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
+    elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args)
+    elif uop == Ops.ASSIGN:
       kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
       r[u] = r[vin[0]]
-    elif uop == UOps.STORE:
+    elif uop == Ops.STORE:
       assert not isinstance(dtype, ImageDType), "unimplemented: image store"
       kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
-    elif uop == UOps.DEFINE_GLOBAL:
+    elif uop == Ops.DEFINE_GLOBAL:
       bufs.append(args)
       signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
       r[u] = args
-    elif uop == UOps.SPECIAL:
+    elif uop == Ops.SPECIAL:
       dims.append(args[1])
       valid.append(f"{args[1]}<{get_max(args[2])}")
       if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
@@ -108,7 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
       kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
       local_size.append(args[2])
       r[u] = args[1]
-    elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
+    elif uop == Ops.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
     else: raise NotImplementedError(f"unimplemented: {uop}")
 
   prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"
diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index a102a9a8e2..88f6f25b9d 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -2,11 +2,12 @@
 from typing import Tuple
 from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
-from tinygrad.ops import UOp, UOps, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
+from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 inf, nan = float('inf'), float('nan')
+UOps = Ops
 
 # kernel unpacker
 from tinygrad.codegen.kernel import Kernel
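The one-line `UOps = Ops` added to extra/optimization/helpers.py above is the whole compatibility story for downstream code: the enum object itself is unchanged, only its name moved, so a module-level alias keeps old call sites working. A sketch of what that buys, assuming nothing beyond the alias in the hunk:

from tinygrad.ops import Ops

UOps = Ops  # transitional alias, exactly as extra/optimization/helpers.py now defines

# both names bind the same enum, so identity checks in old code keep passing
assert UOps is Ops
assert UOps.SINK is Ops.SINK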
diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py
index af9da9b88c..2528a2ca8f 100644
--- a/extra/to_movement_ops.py
+++ b/extra/to_movement_ops.py
@@ -4,7 +4,7 @@ from collections import defaultdict
 from typing import List, Tuple, DefaultDict
 from extra.optimization.helpers import load_worlds, ast_str_to_ast
 from tinygrad.helpers import prod, tqdm
-from tinygrad.ops import UOp, UOps
+from tinygrad.ops import UOp, Ops
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.ops import sym_infer, Node
 
@@ -136,7 +136,7 @@ def test_rebuild(st: ShapeTracker):
     assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
 
 def test_rebuild_bufferop_st(ast:UOp):
-  if ast.op is UOps.SHAPETRACKER:
+  if ast.op is Ops.SHAPETRACKER:
     test_rebuild(ast.arg)
   for src in ast.src:
     test_rebuild_bufferop_st(src)
diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py
index 98e071ca40..66b7592da5 100644
--- a/test/external/external_benchmark_schedule.py
+++ b/test/external/external_benchmark_schedule.py
@@ -2,7 +2,7 @@ from typing import List
 from extra.models.resnet import ResNet50
 from tinygrad import Tensor, Device, nn
 from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
 from tinygrad.codegen.linearize import linearize_uop
@@ -27,7 +27,7 @@ if __name__ == "__main__":
   sched = out.schedule()
 
   if not SCHEDULE_ONLY:
-    asts = list({x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values())
+    asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
     if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
     kernels: List[Kernel] = []
     with Timing(f"***** model opts({len(asts):2d}) in "):
diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py
index 4deca0071f..e9e5ed066c 100644
--- a/test/external/external_test_valid_remove.py
+++ b/test/external/external_test_valid_remove.py
@@ -2,7 +2,7 @@ import unittest
 from tinygrad import Device
-from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 
@@ -13,40 +13,40 @@ class TestOpenpilotValidhack(unittest.TestCase):
   def test_valid_removal(self):
     Device.DEFAULT = "GPU"
 
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
-            x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-              UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
-                UOp(UOps.CAST, dtypes.float, arg=None, src=(
-                  UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-                    UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                      UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
-                      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
-                    UOp(UOps.CAST, dtypes.float, arg=None, src=(
-                      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                        UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
-                        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
-              UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
-                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-            x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
-              x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
-              UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+          UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
+            x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+              UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
+                UOp(Ops.CAST, dtypes.float, arg=None, src=(
+                  UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+                    UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                      UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
+                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
+                    UOp(Ops.CAST, dtypes.float, arg=None, src=(
+                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                        UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
+                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
+              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
+                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+            x19:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
+              x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+          UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+            UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
+              UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
                   x20,)),
-                UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-                  UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
-                    UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+                UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+                  UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
+                    UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
                       x5,
-                      UOp(UOps.CONST, dtypes.float, arg=1.4426950408889634, src=(
+                      UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
                         x20,)),)),)),
-                  x29:=UOp(UOps.CONST, dtypes.float, arg=-1.0, src=(
+                  x29:=UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
                     x20,)),)),)),
             x19,)),
           x29,)),)),)),))
@@ -62,51 +62,51 @@ class TestOpenpilotValidhack(unittest.TestCase):
   def test_const_idx(self):
     Device.DEFAULT = "GPU"
 
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.CAST, dtypes.float, arg=None, src=(
-          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-              UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
-                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
-              UOp(UOps.CAST, dtypes.float, arg=None, src=(
-                UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                  UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                    UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                      UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                        UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                              UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                                x18:=UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
-                                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
-                              UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.CAST, dtypes.float, arg=None, src=(
+          UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+            UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
+                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
+              UOp(Ops.CAST, dtypes.float, arg=None, src=(
+                UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                  UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                    UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                      UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                        UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                          UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                            UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                                x18:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
+                                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
+                              UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                                 x18,
-                                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=45568, mask=((0, 1), (2, 3), (0, 512)), contiguous=False),)), src=()),)),)),
-                            UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=45568, mask=((0, 1), (2, 3), (0, 512)), contiguous=False),)), src=()),)),)),
+                            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                               x18,
-                              UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=43008, mask=((0, 1), (3, 4), (0, 512)), contiguous=False),)), src=()),)),)),
-                          UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=43008, mask=((0, 1), (3, 4), (0, 512)), contiguous=False),)), src=()),)),)),
+                          UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                             x18,
-                            UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=40448, mask=((0, 1), (4, 5), (0, 512)), contiguous=False),)), src=()),)),)),
-                        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                            UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=40448, mask=((0, 1), (4, 5), (0, 512)), contiguous=False),)), src=()),)),)),
+                        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                           x18,
-                          UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=37888, mask=((0, 1), (5, 6), (0, 512)), contiguous=False),)), src=()),)),)),
-                      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=37888, mask=((0, 1), (5, 6), (0, 512)), contiguous=False),)), src=()),)),)),
+                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                         x18,
-                        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=35328, mask=((0, 1), (6, 7), (0, 512)), contiguous=False),)), src=()),)),)),
-                    UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=35328, mask=((0, 1), (6, 7), (0, 512)), contiguous=False),)), src=()),)),)),
+                    UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                       x18,
-                      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=32768, mask=((0, 1), (7, 8), (0, 512)), contiguous=False),)), src=()),)),)),
-                  UOp(UOps.LOAD, dtypes.float, arg=None, src=(
+                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=32768, mask=((0, 1), (7, 8), (0, 512)), contiguous=False),)), src=()),)),)),
+                  UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                     x18,
-                    UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=30208, mask=((0, 1), (8, 9), (0, 512)), contiguous=False),)), src=()),)),)),)),)),
-            UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-              UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
-              UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
+                    UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=30208, mask=((0, 1), (8, 9), (0, 512)), contiguous=False),)), src=()),)),)),)),)),
+            UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+              UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
+              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
 
     opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
     kernel = Kernel(ast)
diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py
index 4b72435d61..eecb750e33 100644
--- a/test/external/fuzz_linearizer.py
+++ b/test/external/fuzz_linearizer.py
@@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
-from tinygrad.ops import UnaryOps, UOp, UOps
+from tinygrad.ops import UnaryOps, UOp, Ops
 from test.helpers import is_dtype_supported
 
 def on_linearizer_will_run(): pass
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
 def _is_simple(lin: Kernel) -> bool:
   if len(lin.ast.src) > 1: return False
   ast:UOp = lin.ast.src[0]
-  if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is UOps.LOAD: return True
+  if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
   return False
 
 if __name__ == "__main__":
diff --git a/test/helpers.py b/test/helpers.py
index d66849a6a2..a3990b18a8 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -2,7 +2,7 @@ import sys, time, logging, difflib
 from typing import Callable, Optional, Tuple
 import numpy as np
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.ops import UOp, UOps, sint
+from tinygrad.ops import UOp, Ops, sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.engine.realize import Runner
@@ -69,7 +69,7 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
 
 def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
   if st_src is None:
    st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
-  return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+  return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
 
 def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
   st = time.perf_counter_ns()
@@ -77,7 +77,7 @@ def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
   return ret, (time.perf_counter_ns()-st)*1e-6
 
 def eval_uop(uop:UOp):
-  g = UOp(UOps.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
+  g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
   rw = full_graph_rewrite(UOp.store(g, UOp.const(dtypes.int, 0), uop).sink(), PythonRenderer)
   prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render("run", linearize_uop(rw))))
   buf = PythonAllocator().alloc(uop.dtype.itemsize)
diff --git a/test/test_const_folding.py b/test/test_const_folding.py
index 8dd7daf0e4..9fa559e20d 100644
--- a/test/test_const_folding.py
+++ b/test/test_const_folding.py
@@ -1,6 +1,6 @@
 import unittest, math
 from tinygrad import Tensor, Device, dtypes
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.helpers import CI
 import numpy as np
@@ -9,7 +9,7 @@ from test.helpers import is_dtype_supported
 def _check_ast_count(desired_count:int, t:Tensor):
   # NOTE: this has side effect because everything can be scheduled only once
   schedule = create_schedule(t.lazydata.lbs)
-  asts = [s for s in schedule if s.ast.op is UOps.SINK]
+  asts = [s for s in schedule if s.ast.op is Ops.SINK]
   assert len(asts) == desired_count
 
 class TestUnaryOpsConstFolding(unittest.TestCase):
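The ast_const helper patched in test/helpers.py above is worth a second look: post-rename, a constant inside a kernel AST is still expressed as an Ops.VALID gate over a ShapeTracker view, selecting the value where the view is valid and 0 elsewhere. A minimal sketch of that construction, reusing only pieces visible in the hunk:

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker

# a scalar constant broadcast to shape (8,), following ast_const above:
# VALID(bool, view) picks val where the expanded view is valid, 0 elsewhere
st = ShapeTracker.from_shape(()).reshape((1,)).expand((8,))
const = UOp(Ops.VALID, dtypes.bool, (st.to_uop(),)).where(UOp.const(dtypes.float, 1.5), UOp.const(dtypes.float, 0))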
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) # run it again to get the kernels - sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is UOps.SINK] + sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK] assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}" - for st in [x.st_arg for x in sched[0].ast.parents if x.op is UOps.LOAD]: + for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]: assert len(st.views) == 1 def test_conv_2x2_backward_one_view(self): @@ -26,7 +26,7 @@ class TestConvShapetracker(unittest.TestCase): conv(X).mean().backward() si = X.grad.schedule()[-1] print(si) - ldb = [x for x in si.ast.parents if x.op is UOps.LOAD][0] + ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0] st: ShapeTracker = ldb.st_arg.simplify() # NOTE: st.real_size() is broken print(si.inputs[0].size) diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 279f0456ca..d00e37c639 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -8,7 +8,7 @@ from tinygrad.dtype import DType from tinygrad.helpers import CI, getenv from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule -from tinygrad.ops import UnaryOps, UOps +from tinygrad.ops import UnaryOps, Ops from tinygrad.tensor import _to_np_dtype from test.helpers import is_dtype_supported import pytest @@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op): np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2) else: np.testing.assert_equal(tensor_value, numpy_value) if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends - op = [x for x in ast.parents if x.op is UOps.ALU and x.arg in UnaryOps][0] + op = [x for x in ast.parents if x.op is Ops.ALU and x.arg in UnaryOps][0] assert op.dtype == dtype def universal_test_cast(a, in_dtype, dtype): diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index 1ca672c92f..168720e2d2 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -2,7 +2,7 @@ import numpy as np import unittest from tinygrad import Tensor, Device, dtypes -from tinygrad.ops import UOps +from tinygrad.ops import Ops from tinygrad.engine.lazy import LazyBuffer, MetaOps from tinygrad.engine.schedule import create_schedule @@ -75,7 +75,7 @@ class TestReduceOp(unittest.TestCase): a = a.sum() sched = create_schedule([a.lazydata]) assert len(sched) == 1 - self.assertIs(sched[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS) + self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS) def test_split_reduce_kernel_dim0(self): a = Tensor.rand(256, 255).realize() @@ -83,7 +83,7 @@ class TestReduceOp(unittest.TestCase): sched = create_schedule([a.lazydata]) assert len(sched) == 2 for s in sched: - self.assertIs(s.ast.src[0].src[2].op, UOps.REDUCE_AXIS) + self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS) def test_split_reduce_kernel_dim1(self): a = Tensor.rand(255, 256).realize() @@ -91,7 +91,7 @@ class TestReduceOp(unittest.TestCase): sched = create_schedule([a.lazydata]) assert len(sched) == 2 for s in sched: - self.assertIs(s.ast.src[0].src[2].op, UOps.REDUCE_AXIS) + self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS) class TestView(unittest.TestCase): def test_all_masked_out(self): diff --git a/test/test_linearizer.py b/test/test_linearizer.py index f8af1eaf20..940b65ca2c 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -6,7 +6,7 
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index f8af1eaf20..940b65ca2c 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -6,7 +6,7 @@ from dataclasses import replace
 from test.helpers import ast_const
 from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
 from tinygrad.codegen.lowerer import get_grouped_dims
-from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps, UnaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
@@ -37,7 +37,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
   k = Kernel(realized_ast)
   k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
   k.linearize()
-  assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
+  assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
   assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
   np_c = np_a @ np_b
   if dtype_out == dtypes.half: tc_atol, tc_rtol = 1e-2, 1e-3
@@ -53,7 +53,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
   k = Kernel(realized_ast)
   k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
   k.linearize()
-  wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
+  wmmas = len([uop for uop in k.uops if uop.op is Ops.WMMA])
   tcs = len([x for x in k.applied_opts if x.op is OptOps.TC])
   if ensure_triggered:
     assert wmmas > 0, "tensor core not triggered"
@@ -84,19 +84,19 @@ class TestLinearizer(unittest.TestCase):
   def test_multioutput(self):
     dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
-    g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
-    a = UOp(UOps.LOAD, dtype, (g2, st.to_uop()))
-    b = UOp(UOps.LOAD, dtype, (g3, st.to_uop()))
-    out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b))
-    out1 = UOp(UOps.STORE, dtypes.void, (g1, st.to_uop(), a * b))
-    sink = UOp(UOps.SINK, src=(out0, out1))
+    g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
+    a = UOp(Ops.LOAD, dtype, (g2, st.to_uop()))
+    b = UOp(Ops.LOAD, dtype, (g3, st.to_uop()))
+    out0 = UOp(Ops.STORE, dtypes.void, (g0, st.to_uop(), a + b))
+    out1 = UOp(Ops.STORE, dtypes.void, (g1, st.to_uop(), a * b))
+    sink = UOp(Ops.SINK, src=(out0, out1))
     a_t = Tensor.full(st.shape, 2).contiguous().realize()
     b_t = Tensor.full(st.shape, 3).contiguous().realize()
     lin = helper_linearizer_ast(sink, [a_t, b_t], wanna_output=[a_t.numpy()+b_t.numpy(), a_t.numpy()*b_t.numpy()])[0]
-    stores = [u for u in lin.uops if u.op is UOps.STORE]
-    mutable_bufs = dedup(flatten([[x for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL] for u in stores]))
+    stores = [u for u in lin.uops if u.op is Ops.STORE]
+    mutable_bufs = dedup(flatten([[x for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL] for u in stores]))
     assert len(mutable_bufs) == len(stores) == 2
     assert [u.arg for u in mutable_bufs] == [0, 1]
@@ -107,14 +107,14 @@ class TestLinearizer(unittest.TestCase):
     Tensor.manual_seed(0)
     x = Tensor.randn(32, dtype=dtypes.float).realize()
     st_x = x.lazydata.st
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
     diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,)))
-    store = UOp(UOps.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,)))
+    store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [
       [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
       [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)],
@@ -131,7 +131,7 @@ class TestLinearizer(unittest.TestCase):
     wanna_output = (x.numpy()-x.numpy().sum(-1, keepdims=True)).sum(-1).reshape(1,1)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
     for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
       for i,u in enumerate(ranges):
         if i == 0: continue
        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@@ -143,14 +143,14 @@ class TestLinearizer(unittest.TestCase):
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
     st_x = x.lazydata.st
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [
       # locals
       [Opt(OptOps.LOCAL, 0, 3)],
@@ -195,7 +195,7 @@ class TestLinearizer(unittest.TestCase):
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
     for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
       for i,u in enumerate(ranges):
         if i == 0: continue
         assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@@ -205,21 +205,21 @@ class TestLinearizer(unittest.TestCase):
     x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
     x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
     x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
-    g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
+    g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,)))
-    third_x = UOp(UOps.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,)))
+    third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop()))
     mul = (third_x*second_reduce)
-    third_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5)
     lins = helper_linearizer_ast(sink, [x0,x1,x2], wanna_output=[wanna_output])
     for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
       for i,u in enumerate(ranges):
         if i == 0: continue
         assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
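(Context, not part of the patch: the multireduce hunks in this file all rebuild one hand-written AST skeleton, which is why the rename touches so many constructor calls. A minimal single-reduce version of that skeleton in the new spelling; the constructors appear verbatim in the surrounding hunks, only the variable names and shapes here are illustrative.)

    g0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)  # output buffer
    g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1)  # input buffer
    ld = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker.from_shape((32,)).to_uop()))
    red = UOp(Ops.REDUCE_AXIS, dtypes.float, (ld,), (BinaryOps.ADD, (0,)))
    st = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), red))
    sink = UOp(Ops.SINK, src=(st,))  # every kernel AST roots at SINK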
@@ -232,15 +232,15 @@ class TestLinearizer(unittest.TestCase):
     Tensor.manual_seed(0)
     x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
     st = x.lazydata.st
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
     neg_first_reduce = first_reduce * ast_const(dtypes.float, -1, (8, 32, 1, 8, 16, 1))
     squares = (second_x+neg_first_reduce)
-    squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,))
-    sink = UOp(UOps.SINK, src=(store,))
+    squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,))
+    sink = UOp(Ops.SINK, src=(store,))
     wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1))
     opts = [
       # openCL / GPU=1 is 256 max threads
@@ -271,7 +271,7 @@ class TestLinearizer(unittest.TestCase):
     ]
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
     for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
       for i,u in enumerate(ranges):
         if i < 2: continue
         assert ranges[i-2] != u or ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-2], ranges[i-1], {u}}"
@@ -283,14 +283,14 @@ class TestLinearizer(unittest.TestCase):
     # check how it works with one reduce optimized and one unoptimized
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [
       [Opt(OptOps.GROUPTOP, 0, 3)], # grouping
       [Opt(OptOps.GROUPTOP, 1, 3)],
@@ -302,7 +302,7 @@ class TestLinearizer(unittest.TestCase):
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     lins = helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
     for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
      for i,u in enumerate(ranges):
        if i == 0: continue
        assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@@ -314,16 +314,16 @@ class TestLinearizer(unittest.TestCase):
     Tensor.manual_seed(0)
     x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
     x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
-    g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
-    first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    first_reduce_p = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
+    g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
+    first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop()))
     diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [
       # [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)], # grouping
      # [Opt(OptOps.GROUPTOP, 0, 8), Opt(OptOps.GROUPTOP, 1, 8)],
@@ -340,7 +340,7 @@ class TestLinearizer(unittest.TestCase):
     wanna_output = (x.numpy()-(x.numpy().sum(-1, keepdims=True)+np.exp2(x_p.numpy()).sum(-1, keepdims=True))).sum(-1).reshape(4, 1,1)
     lins = helper_linearizer_ast(sink, [x,x_p], wanna_output=[wanna_output], opts=opts)
     for l in lins:
-      ranges = [u.op for u in l.uops if (u.op is UOps.RANGE and u.arg[1]) or (u.op is UOps.ENDRANGE and u.src[0].arg[1])]
+      ranges = [u.op for u in l.uops if (u.op is Ops.RANGE and u.arg[1]) or (u.op is Ops.ENDRANGE and u.src[0].arg[1])]
       for i,u in enumerate(ranges):
         if i == 0: continue
         assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"
@@ -350,16 +350,16 @@ class TestLinearizer(unittest.TestCase):
     # check how multireduce works with multioutput
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
-    g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
+    g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store0 = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
     second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5))
-    store1 = UOp(UOps.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out))
-    sink = UOp(UOps.SINK, src=(store0, store1))
+    store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out))
+    sink = UOp(Ops.SINK, src=(store0, store1))
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output, wanna_output/15])
@@ -373,18 +373,18 @@ class TestLinearizer(unittest.TestCase):
     # if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!)
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
-    g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
+    g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store0 = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
-    store1 = UOp(UOps.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501
     wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     wanna_output1 = x.numpy().sum(axis=1).reshape(27,1,1,5)
-    ast = UOp(UOps.SINK, src=(store0, store1))
+    ast = UOp(Ops.SINK, src=(store0, store1))
     k = Kernel(ast)
     prg = CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
     inbufs = [x.lazydata.base.buffer]
@@ -397,14 +397,14 @@ class TestLinearizer(unittest.TestCase):
   def test_complete_unroll_multireduce(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
@@ -413,14 +413,14 @@ class TestLinearizer(unittest.TestCase):
   def test_upcast_multireduce(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [[Opt(OptOps.UPCAST, 0, 3)]]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
@@ -432,14 +432,14 @@ class TestLinearizer(unittest.TestCase):
     # make sure the if block of a grouped reduce can be closed early and the result loaded back in
     Tensor.manual_seed(0)
     x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
     diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5)))
-    second_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
-    sink = UOp(UOps.SINK, src=(store,))
+    second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce))
+    sink = UOp(Ops.SINK, src=(store,))
     opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]]
     wanna_output = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5)
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output], opts=opts)
@@ -448,17 +448,17 @@ class TestLinearizer(unittest.TestCase):
   def test_mean_std_multireduce(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
     neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
     squares = (second_x+neg_mean)*(second_x+neg_mean)
-    squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+    squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
     variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
     std = variance.alu(UnaryOps.SQRT)
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
+    sink = UOp(Ops.SINK, src=(store,))
     wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1))
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
@@ -466,17 +466,17 @@ class TestLinearizer(unittest.TestCase):
   def test_mean_std_multireduce_mid_dim(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
     neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop()))
     squares = (second_x+neg_mean)*(second_x+neg_mean)
-    squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,)))
+    squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,)))
     variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35))
     std = variance.alu(UnaryOps.SQRT)
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std))
+    sink = UOp(Ops.SINK, src=(store,))
     wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35))
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
@@ -486,41 +486,41 @@ class TestLinearizer(unittest.TestCase):
     # TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch)
     Tensor.manual_seed(0)
     x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
-    g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+    g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
     neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop()))
     squares = (second_x+neg_mean)*(second_x+neg_mean)
-    squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+    squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
     variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1))
     std = variance.alu(UnaryOps.SQRT)
-    store_mean = UOp(UOps.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean))
-    store_std = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
-    sink = UOp(UOps.SINK, src=(store_std, store_mean))
+    store_mean = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean))
+    store_std = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std))
+    sink = UOp(Ops.SINK, src=(store_std, store_mean))
     wanna_output = [x.numpy().std(axis=2, ddof=0).reshape(15,25,1,1), x.numpy().mean(axis=2).reshape(15,25,1,1)]
     lins = helper_linearizer_ast(sink, [x], wanna_output=wanna_output)
     for k in lins:
-      assert len([u for u in k.uops if u.op is UOps.DEFINE_ACC]) == 2, "got more than two accs (implies the kernel didn't reuse the mean reduce)"
+      assert len([u for u in k.uops if u.op is Ops.DEFINE_ACC]) == 2, "got more than two accs (implies the kernel didn't reuse the mean reduce)"
 
   @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "ocelot/remu doesn't have multiple wave syncs yet")
   def test_var_multireduce(self):
     Tensor.manual_seed(0)
     x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
     # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
-    first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
+    first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
     neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1))
     # store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean))
     # verify_lazyop(store)
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop()))
     squares = (second_x+neg_mean)*(second_x+neg_mean)
-    squares_sum = UOp(UOps.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
+    squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,)))
     variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance))
+    sink = UOp(Ops.SINK, src=(store,))
     wanna_output = x.numpy().var(axis=2, ddof=0).reshape((3,27,1,1))
     helper_linearizer_ast(sink, [x], wanna_output=[wanna_output])
     # tinygrad ref
@@ -530,17 +530,17 @@ class TestLinearizer(unittest.TestCase):
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
   def test_softmax_multireduce(self):
     x = Tensor.rand(4, 32).realize()
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
-    max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
-    second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
+    max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
+    second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
     centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1))
     exp_x = centered_x.alu(UnaryOps.EXP2)
-    sum_exp_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,)))
+    sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,)))
     # y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape
     recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP)
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x))
+    sink = UOp(Ops.SINK, src=(store,))
     expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1)
     helper_linearizer_ast(sink, [x], wanna_output=[expected])
@@ -556,24 +556,24 @@ class TestLinearizer(unittest.TestCase):
                                           View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
     arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
     arange_axis = (3,)
-    arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
+    arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
     output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
     out = arange+ast_const(dtypes.int, -1, output_shape)
-    store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
+    sink = UOp(Ops.SINK, src=(store,))
     helper_linearizer_ast(sink, [], wanna_output=[real_arange])
     with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
 
   @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
   def test_indexing_multireduce(self):
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    g2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2)
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    g2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2)
     arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
                                           View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
     # TODO: do this arange broadcast in the scheduler
     arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
     arange_axis = (3,)
-    arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
+    arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
     arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
     arange = arange+ast_const(dtypes.int, -1, arange_out_shape)
     # p2: the indexing
@@ -581,13 +581,13 @@ class TestLinearizer(unittest.TestCase):
     data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop())
     idxs = Tensor([0,3,5,6]).realize()
     data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop())
-    arange_eq = arange.alu(BinaryOps.CMPNE, UOp(UOps.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
-    reduce_input = UOp(UOps.LOAD, dataset.dtype, data1)*UOp(UOps.CAST, dataset.dtype.scalar(), src=(arange_eq,))
+    arange_eq = arange.alu(BinaryOps.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape))
+    reduce_input = UOp(Ops.LOAD, dataset.dtype, data1)*UOp(Ops.CAST, dataset.dtype.scalar(), src=(arange_eq,))
     out_axis = (1,)
-    out = UOp(UOps.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis))
+    out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis))
     output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out))
+    sink = UOp(Ops.SINK, src=(store,))
     real_index = dataset.numpy()[idxs.numpy()].reshape(4, 1, 256, 1)
     helper_linearizer_ast(sink, [dataset, idxs], wanna_output=[real_index])
@@ -596,29 +596,29 @@ class TestLinearizer(unittest.TestCase):
     t = Tensor.randn(10, 20).realize()
     t_max = t.max((0,)).realize()
     real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
-        UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-          UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
+        UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+          UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
            ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
-            UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
+            UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
              ast_const(dtypes.int, -1, (1, 20, 1)),
-              UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
-                UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
-                  UOp(UOps.CAST, dtypes.int, arg=None, src=(
-                    UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-                          UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
-                        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
-                          UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
+              UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
+                UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
+                  UOp(Ops.CAST, dtypes.int, arg=None, src=(
+                    UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+                          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
+                        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
+                          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
                       ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
-                  UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-                    UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
+                  UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+                    UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
                       ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
                       ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
            ast_const(dtypes.int, -1, (1, 20, 1)),)),)),))
@@ -628,29 +628,29 @@ class TestLinearizer(unittest.TestCase):
     t = Tensor.randn(10, 20).realize()
     t_max = t.max().realize()
     real_argmax = np.argmax(t.numpy())
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
-        UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-          UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
+        UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+          UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
           ast_const(dtypes.int, 200, (1, 1)),
-            UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
+            UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
              ast_const(dtypes.int, -1, (1, 1)),
-              UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
-                UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
-                  UOp(UOps.CAST, dtypes.int, arg=None, src=(
-                    UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-                          UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
-                        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
-                          UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
+              UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
+                UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
+                  UOp(Ops.CAST, dtypes.int, arg=None, src=(
+                    UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+                          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
+                        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
+                          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
                       ast_const(dtypes.bool, True, (200, 1)),)),)),
-                  UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-                    UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+                  UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+                    UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
                       ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
                       ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
           ast_const(dtypes.int, -1, (1, 1)),)),)),))
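(Context, not part of the patch: the two argmax hunks above encode argmax without a dedicated primitive: an equality mask built from two CMPNEs, multiplied by a reversed arange built with a masked view, max-reduced, and subtracted from the length. The same identity in numpy terms, as a sketch for intuition only.)

    import numpy as np
    t = np.random.randn(10)
    n = t.shape[0]
    mask = (t == t.max())  # CMPNE(CMPNE(a, b), True) in the AST above
    assert (n - 1) - (mask * np.arange(n - 1, -1, -1)).max() == np.argmax(t)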
@@ -669,21 +669,21 @@ class TestLinearizer(unittest.TestCase):
       # [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)]
     ]
 
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
-    x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
-    r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
-    r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
-    sink = UOp(UOps.SINK, src=(store,))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
+    x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
+    r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
+    r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
+    sink = UOp(Ops.SINK, src=(store,))
     helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts)
 
-    x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
-    x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
-    r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,)))
-    r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
-    sink = UOp(UOps.SINK, src=(store,))
+    x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
+    x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
+    r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,)))
+    r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
+    sink = UOp(Ops.SINK, src=(store,))
     helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts)
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@@ -696,21 +696,21 @@ class TestLinearizer(unittest.TestCase):
       [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),]
     ]
 
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
-    x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
-    x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
-    r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
-    r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
-    sink = UOp(UOps.SINK, src=(store,))
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
+    x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
+    x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
+    r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
+    r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1))
+    sink = UOp(Ops.SINK, src=(store,))
     helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts)
 
-    x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
-    x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
-    r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,)))
-    r1 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,)))
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
-    sink = UOp(UOps.SINK, src=(store,))
+    x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop()))
+    x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop()))
+    r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,)))
+    r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,)))
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1))
+    sink = UOp(Ops.SINK, src=(store,))
     helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts)
 
   @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@@ -728,31 +728,31 @@ class TestLinearizer(unittest.TestCase):
     wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=1,keepdims=True), a.numpy(), b.numpy())).sum(axis=1),0.0,1.0).reshape((N,1,1)) # noqa: E501
     ld0 = x.lazydata.st.reshape((N, 1, N)).expand((N,N,N))
     ld1 = x.lazydata.st.reshape((N, N, 1))
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
-        UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
-          UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
+        UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
+          UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
            ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
-            UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
-              UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                UOp(UOps.LOAD, dtypes.float, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
+            UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+              UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                UOp(Ops.LOAD, dtypes.float, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
                   ld1.to_uop(),)),
-                UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
-                  UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
+                UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
+                  UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
                    ast_const(dtypes.float, 0.75*N, (N, N, 1)),
-                    UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
-                      UOp(UOps.LOAD, dtypes.float, src=(
-                        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
+                    UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+                      UOp(Ops.LOAD, dtypes.float, src=(
+                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
                         ld0.to_uop(),)),)),)),
-                  UOp(UOps.LOAD, dtypes.float, src=(
-                    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
-                    UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
-                  UOp(UOps.LOAD, dtypes.float, src=(
-                    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
-                    UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
+                  UOp(Ops.LOAD, dtypes.float, src=(
+                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
+                    UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
+                  UOp(Ops.LOAD, dtypes.float, src=(
+                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
+                    UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
          ast_const(dtypes.float, 0.0, (N, 1, 1)),
          ast_const(dtypes.float, 1.0, (N, 1, 1)),)),)),))
@@ -761,31 +761,31 @@ class TestLinearizer(unittest.TestCase):
     ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N))
     ld1 = x.lazydata.st.reshape((N, 1, N))
     wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
-          UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
+          UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
           ast_const(dtypes.float, 0.5*N, (1, 1, N)),
-            UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
-              UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                UOp(UOps.LOAD, dtypes.float, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+            UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
+              UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                UOp(Ops.LOAD, dtypes.float, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
                   ld1.to_uop(),)),
-                UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
-                  UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
+                UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
+                  UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
                    ast_const(dtypes.float, 0.75*N, (N, 1, N)),
-                    UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
-                      UOp(UOps.LOAD, dtypes.float, src=(
-                        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+                    UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+                      UOp(Ops.LOAD, dtypes.float, src=(
+                        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
                         ld0.to_uop(),)),)),)),
-                  UOp(UOps.LOAD, dtypes.float, src=(
-                    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
-                    UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
-                  UOp(UOps.LOAD, dtypes.float, src=(
-                    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
-                    UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
+                  UOp(Ops.LOAD, dtypes.float, src=(
+                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
+                    UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
+                  UOp(Ops.LOAD, dtypes.float, src=(
+                    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
+                    UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
          ast_const(dtypes.float, 0.0, (1, 1, N)),
          ast_const(dtypes.float, 1.0, (1, 1, N)),)),)),))
@@ -797,31 +797,31 @@ class TestLinearizer(unittest.TestCase):
     ld0 = x.lazydata.st.reshape((1,1,N,N)).expand((N,N,N,N))
     ld1 = x.lazydata.st.reshape((N,N,1,1))
     wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
-        UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
-          UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
+        UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
+          UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
          ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
-            UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
-              UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-                UOp(UOps.LOAD, dtypes.float, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
-                  UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
-                UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
-                  UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
+            UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
+              UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+                UOp(Ops.LOAD, dtypes.float, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
+                  UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
+                UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
+                  UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
                    ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
-                    UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
-                      UOp(UOps.LOAD, dtypes.float, src=(
-                        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
-                        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
-                  UOp(UOps.LOAD, dtypes.float, src=(
-                    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
-                    UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
-                  UOp(UOps.LOAD, dtypes.float, src=(
-                    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
-                    UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
+                    UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
+                      UOp(Ops.LOAD, dtypes.float, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), + UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2), + UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3), + UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501 ast_const(dtypes.float, 0.0, (1, 1, 1, 1)), ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),)) helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output]) @@ -829,21 +829,21 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_end_local(self): - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)] - load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop())) - reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,))) - store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce)) - sink = UOp(UOps.SINK, src=(store,)) + g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)] + load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop())) + reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,))) + store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce)) + sink = UOp(Ops.SINK, src=(store,)) load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize() k = helper_linearizer_ast(sink, [load_t], wanna_output=[load_t.numpy().sum()])[1] - self.assertEqual(k.uops[-1].op, UOps.ENDIF) - self.assertLess(k.uops.index([x for x in k.uops if x.op is UOps.STORE][-1]), k.uops.index(k.uops[-1])) + self.assertEqual(k.uops[-1].op, Ops.ENDIF) + self.assertLess(k.uops.index([x for x in k.uops if x.op is Ops.STORE][-1]), k.uops.index(k.uops[-1])) def test_two_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)).sum()])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now # RANGE -> LOAD -> RANGE -> ASSIGN #assert any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]]) @@ -852,7 +852,7 @@ class TestLinearizer(unittest.TestCase): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).expand(2, 2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[np.broadcast_to(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)), (2, 2, 3)).sum()])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now # RANGE -> RANGE -> LOAD -> RANGE -> ASSIGN # NOTE: nothing should toposort between the first two ranges @@ -863,27 +863,27 @@ class TestLinearizer(unittest.TestCase): a = Tensor([2, 2]).realize() out = a.reshape(2, 1).pad(((1, 1), (1, 1)), 2).sum() lin = 
helper_linearizer_opt(out, wanna_output=[24])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE] # RANGE -> ALU -> RANGE -> ALU + LOAD -> ASSIGN - assert any(x.op is UOps.ALU for x in lin.uops[ranges[0]:ranges[1]]) - assert not any(x.op is UOps.LOAD for x in lin.uops[ranges[0]:ranges[1]]) - assert any(x.op in {UOps.ALU, UOps.LOAD} for x in lin.uops[ranges[1]:]) + assert any(x.op is Ops.ALU for x in lin.uops[ranges[0]:ranges[1]]) + assert not any(x.op is Ops.LOAD for x in lin.uops[ranges[0]:ranges[1]]) + assert any(x.op in {Ops.ALU, Ops.LOAD} for x in lin.uops[ranges[1]:]) def test_range_outer_op_before_phi(self): a = Tensor.randn(4, 1).realize() b = Tensor.randn(1, 1).realize() out = (a + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE] # LOAD -> RANGE -> LOAD -> ASSIGN - assert len([x for x in lin.uops[:ranges[0]] if x.op is UOps.LOAD]) == 1 + assert len([x for x in lin.uops[:ranges[0]] if x.op is Ops.LOAD]) == 1 def test_range_outer_op_before_phi_nested_range(self): a = Tensor.randn(2, ).realize() b = Tensor.randn(1, 1).realize() out = (a.reshape(2, 1).expand(2, 3) + b[0]).sum() + b[0] lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3)) + b.numpy()[0]).sum() + b.numpy()])[0] - ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] + ranges = [i for i,u in enumerate(lin.uops) if u.op is Ops.RANGE] assert len(ranges) == 1 # NOTE: it collapses now #if getenv("PTX"): # LOAD -> RANGE -> CAST -> ALU -> ALU -> LOAD -> ALU -> RANGE -> ALU -> ASSIGN @@ -901,16 +901,16 @@ class TestLinearizer(unittest.TestCase): out = a.sum() * a.sum() lin = helper_linearizer_opt(out, wanna_output=[a.numpy().sum()*a.numpy().sum()])[0] # RANGE -> LOAD -> ASSIGN -> ALU - end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) - assert lin.uops[end+1].op is UOps.ALU + end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE) + assert lin.uops[end+1].op is Ops.ALU def test_range_outer_op_after_phi_nested_range(self): a = Tensor.randn(2, ).realize() out = a.reshape(2, 1).expand(2, 3).sum() + a.reshape(2, 1).expand(2, 3).sum() lin = helper_linearizer_opt(out, wanna_output=[(np.broadcast_to(a.numpy().reshape(2, 1), (2, 3))).sum()*2])[0] # RANGE -> LOAD -> ASSIGN -> ALU - end = max(i for i,u in enumerate(lin.uops) if u.op is UOps.ENDRANGE) - assert lin.uops[end+1].op is UOps.ALU + end = max(i for i,u in enumerate(lin.uops) if u.op is Ops.ENDRANGE) + assert lin.uops[end+1].op is Ops.ALU def test_load_dedup(self): # for different leaves in the AST, the same loads may occur. @@ -922,7 +922,7 @@ class TestLinearizer(unittest.TestCase): k = Kernel(create_schedule([r.lazydata])[-1].ast) k.upcast() k.linearize() - num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD]) + num_loads = len([uop for uop in k.uops if uop.op is Ops.LOAD]) assert num_loads <= 4, "more load uops than needed" assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?" 
@@ -931,15 +931,15 @@ class TestLinearizer(unittest.TestCase):
     # make sure const buffers are differentiated from local and mem buffers
     ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int
     VAL = ast_const(DT, 2, ST.arg.shape)
-    g0, g1 = [UOp(UOps.DEFINE_GLOBAL, DT.ptr(), arg=i) for i in range(2)]
+    g0, g1 = [UOp(Ops.DEFINE_GLOBAL, DT.ptr(), arg=i) for i in range(2)]
 
     # data1[0] + VAL
-    a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL
+    a = UOp(Ops.LOAD, DT, (g1, ST)) + VAL
     # (literal const 1) + VAL
     b = ast_const(DT, 1, ST.arg.shape) + VAL
 
-    store = UOp(UOps.STORE, src=(g0, ST, (a+b)))
-    sink = UOp(UOps.SINK, src=(store,))
+    store = UOp(Ops.STORE, src=(g0, ST, (a+b)))
+    sink = UOp(Ops.SINK, src=(store,))
     lin = Kernel(sink)
     lin.linearize()
     assert len(lin.uops) <= 9, "too many uops"
@@ -953,7 +953,7 @@ class TestLinearizer(unittest.TestCase):
     k = Kernel(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
-    num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
+    num_ops = len([uop for uop in k.uops if uop.op is Ops.ALU])
     assert num_ops <= 1, "more alu uops than needed"
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
@@ -965,8 +965,8 @@ class TestLinearizer(unittest.TestCase):
     k.upcast()
     k.upcast()
     k.linearize()
-    accs = [u for u in k.uops if u.op is UOps.DEFINE_ACC]
-    stores = [u for u in k.uops if u.op is UOps.STORE]
+    accs = [u for u in k.uops if u.op is Ops.DEFINE_ACC]
+    stores = [u for u in k.uops if u.op is Ops.STORE]
     assert len(accs) == 0  # it's removed now
     assert len(stores) == 1
     assert stores[0].src[-1].dtype == dtypes.float.vec(4)
@@ -982,14 +982,14 @@ class TestLinearizer(unittest.TestCase):
     k.hand_coded_optimizations()
     k.linearize()
 
-    stores = [u for u in k.uops if u.op is UOps.STORE]
+    stores = [u for u in k.uops if u.op is Ops.STORE]
 
     # the first store is to lds and can be upcasted
     assert stores[0].src[-1].dtype == dtypes.float.vec(4)
-    assert any(x.op is UOps.DEFINE_LOCAL for x in stores[0].sparents)
+    assert any(x.op is Ops.DEFINE_LOCAL for x in stores[0].sparents)
     # the second store is to gds with no upcasts
     assert stores[1].src[-1].dtype == dtypes.float
-    assert any(x.op is UOps.DEFINE_GLOBAL for x in stores[1].sparents)
+    assert any(x.op is Ops.DEFINE_GLOBAL for x in stores[1].sparents)
 
   def test_zero_fold(self):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
@@ -998,7 +998,7 @@ class TestLinearizer(unittest.TestCase):
     k = Kernel(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
-    num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
+    num_ops = len([uop for uop in k.uops if uop.op is Ops.ALU])
     assert num_ops == 0, "more alu uops than needed"
 
   def test_sum_acc_dtype(self):
@@ -1007,14 +1007,14 @@ class TestLinearizer(unittest.TestCase):
       a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
       k = Kernel(create_schedule([a.lazydata])[-1].ast)
       k.linearize()
-      local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
+      local = [uop for uop in k.uops if uop.op is Ops.DEFINE_ACC]
       assert local[0].dtype == acc_dtype
 
   def test_arg_acc_dtype(self):
     def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
       k = Kernel(create_schedule([c.lazydata])[-1].ast)
       k.linearize()
-      local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
+      local = [uop for uop in k.uops if uop.op is Ops.DEFINE_ACC]
       assert local[0].dtype == expected_dtype
 
     tests = (
@@ -1081,7 +1081,7 @@ class TestLinearizer(unittest.TestCase):
       k = Kernel(realized_ast)
      k.apply_tensor_cores(1, axis=axis, tc_opt=2)
      k.linearize()
-      assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
+      assert len([uop for uop in k.uops if uop.op is Ops.WMMA]) > 0, "tensor core not triggered"
      assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
 
      prg = CompiledRunner(k.to_program())
@@ -1105,8 +1105,8 @@ class TestLinearizer(unittest.TestCase):
      r = x.matmul(y, acc_dtype=tc.dtype_out)
      k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
      for u in k.uops:
-        if u.op is UOps.WMMA:
-          assert u.src[-1].src[0].op != UOps.ASSIGN
+        if u.op is Ops.WMMA:
+          assert u.src[-1].src[0].op != Ops.ASSIGN
 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
@@ -1116,9 +1116,9 @@ class TestLinearizer(unittest.TestCase):
      r = x.matmul(y, acc_dtype=tc.dtype_out)
      k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
      for u in k.uops:
-        if u.op is UOps.WMMA:
+        if u.op is Ops.WMMA:
          #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
-          assert u.src[-1].src[0].op != UOps.ASSIGN
+          assert u.src[-1].src[0].op != Ops.ASSIGN
 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
  @unittest.skipIf(Device.DEFAULT in {"CLANG"}, "CLANG does not support using a different type for accumulation")
@@ -1129,9 +1129,9 @@ class TestLinearizer(unittest.TestCase):
      r = x.matmul(y, acc_dtype=tc.dtype_out).relu()
      k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4)]], apply_tc=True, atol=3e-2, rtol=1e-3)[-1]
      for u in k.uops:
-        if u.op is UOps.WMMA:
+        if u.op is Ops.WMMA:
          #assert u.src[-1].dtype == dtypes.float.vec(prod(tc.thread_local_sizes[2]))
-          assert u.src[-1].src[0].op != UOps.ASSIGN
+          assert u.src[-1].src[0].op != Ops.ASSIGN
 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_simple_unroll_no_between_phi_dependencies(self):
@@ -1140,17 +1140,17 @@ class TestLinearizer(unittest.TestCase):
    k = helper_linearizer_opt(r, [[Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4)]])[-1]
    # the uops graph is RANGE -> DEFINE_ACC -> 4x ALU -> 4x ASSIGN -> ENDRANGE
    for u in k.uops:
-      if u.op is UOps.ASSIGN:
-        assert u.src[1].op is UOps.ALU
+      if u.op is Ops.ASSIGN:
+        assert u.src[1].op is Ops.ALU
      # children of ASSIGN are placed after ENDRANGE
-      if any(x.op is UOps.ASSIGN for x in u.src):
-        end_range = [i for i, x in enumerate(k.uops) if x.op is UOps.ENDRANGE][0]
+      if any(x.op is Ops.ASSIGN for x in u.src):
+        end_range = [i for i, x in enumerate(k.uops) if x.op is Ops.ENDRANGE][0]
        assert end_range < k.uops.index(u)
 
  def test_grouped_dims(self):
    def _assert_grouped_dims(prefix, dims, max_sizes, reverse_dims, expected_sizes):
      idxs = get_grouped_dims(prefix, dims, max_sizes, reverse_dims)
-      loop_idxs = dedup(flatten([[y for y in x.sparents if y.op is UOps.SPECIAL] for x in idxs]))
+      loop_idxs = dedup(flatten([[y for y in x.sparents if y.op is Ops.SPECIAL] for x in idxs]))
      loop_idxs = sorted(loop_idxs, key=lambda uop: uop.arg[0])
      sizes = [x.arg[1] for x in loop_idxs]
      assert len(idxs) == len(dims), f"expected idxs to have same length as dims {len(dims)}, got {len(idxs)}"
@@ -1207,7 +1207,7 @@ class TestLinearizer(unittest.TestCase):
    # shrink so that the dims do not collapse
    t = Tensor.ones(5, 6, 7).contiguous().realize().shrink(((0, 4), (0, 5), (0, 6)))
    k = helper_linearizer_opt(t+1)[0]
-    idxs = dedup([uop for uop in k.uops if uop.op is UOps.SPECIAL])
+    idxs = dedup([uop for uop in k.uops if uop.op is Ops.SPECIAL])
    idxs = sorted(idxs, key=lambda uop: uop.arg[0])
    assert idxs[0].arg == ('gidx0', 6), idxs[0].arg
    assert idxs[1].arg == ('gidx1', 5), idxs[1].arg
@@ -1215,7 +1215,7 @@ class TestLinearizer(unittest.TestCase):
 
  def test_div_collapse(self):
    def helper(t, msg, max_ops=0):
-      sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is UOps.SINK]
+      sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is Ops.SINK]
      assert len(sched) == 1
      lin = Kernel(sched[0].ast)
 
@@ -1236,10 +1236,10 @@ class TestLinearizer(unittest.TestCase):
 
  def test_sum_collapse(self):
    t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
-    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is UOps.SINK]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is Ops.SINK]
    assert len(sched) == 1
    lin = Kernel(sched[0].ast)
-    assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
+    assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
 
  def test_assign_fold(self):
    a = Tensor.ones(4, 4).contiguous().realize()
@@ -1267,10 +1267,10 @@ class TestLinearizer(unittest.TestCase):
      k = helper_linearizer_opt(t)[-1]
      uops = list(k.linearize().uops)
      # ignore kernel optimized IF statements for now
-      if if_op:=next((u for u in uops if u.op is UOps.IF), None):
+      if if_op:=next((u for u in uops if u.op is Ops.IF), None):
        uops = uops[:uops.index(if_op)]
-      assert len(set([u.op for u in uops if u.op in {UOps.RANGE, UOps.SPECIAL}])) == 1, "has either specials or ranges, not both"
-      assert len([u for u in uops if u.op is UOps.ASSIGN]) == 0, "ASSIGN should have been simplified"
+      assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both"
+      assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified"
      # TODO: once uops track min/max this will be fixed
      #assert len([u for u in uops if u.arg is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops"
 
@@ -1296,7 +1296,7 @@ class TestLinearizer(unittest.TestCase):
    out = x.matmul(y)
    k = helper_linearizer_opt(out)[-1]
    # check that the float4 cast collapses
-    store_vals = [u.src[-1] for u in k.uops if u.op is UOps.STORE]
+    store_vals = [u.src[-1] for u in k.uops if u.op is Ops.STORE]
    for val in store_vals:
      assert val.dtype == dtypes.float.vec(4) # and val.op is not UOps.VECTORIZE
 
@@ -1319,8 +1319,8 @@ class TestLinearizer(unittest.TestCase):
    x = Tensor.randn((4,3,6,6)).realize()
    out = x.flip((0,1)).contiguous()
    k = helper_linearizer_opt(out)[-1]
-    store_val = [u.src[-1] for u in k.uops if u.op is UOps.STORE][0]
-    assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not UOps.VECTORIZE
+    store_val = [u.src[-1] for u in k.uops if u.op is Ops.STORE][0]
+    assert store_val.dtype == dtypes.float.vec(4) and store_val.op is not Ops.VECTORIZE
 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -1332,16 +1332,16 @@ class TestLinearizer(unittest.TestCase):
           Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
    k = helper_linearizer_opt(out, opts=[opt])[-1]
    def get_recursive(uop): return set.union(set(uop.src), [uop], *[get_recursive(v) for v in uop.src])
-    local_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
-    global_stores = [u for u in k.uops if u.op is UOps.STORE and any(x.op is UOps.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
-    barrier = [u for u in k.uops if u.op is UOps.BARRIER][0]
+    local_stores = [u for u in k.uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_LOCAL for x in get_recursive(u.src[0]))]
+    global_stores = [u for u in k.uops if u.op is Ops.STORE and any(x.op is Ops.DEFINE_GLOBAL for x in get_recursive(u.src[0]))]
+    barrier = [u for u in k.uops if u.op is Ops.BARRIER][0]
    # check that the float4 cast collapses for all stores
    for store in local_stores+global_stores:
      assert store.src[-1].dtype.count > 1 # and store.src[2].op is not UOps.VECTORIZE
    # # check the children's vins
    # TODO: src ALU are not the same, should it?
    # assert barrier.src == tuple(local_stores)
-    assert len([u for u in k.uops if u.op is UOps.IF and u.src[-1] == barrier]) == 1
+    assert len([u for u in k.uops if u.op is Ops.IF and u.src[-1] == barrier]) == 1
 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
@@ -1350,7 +1350,7 @@ class TestLinearizer(unittest.TestCase):
    x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
    r = (x@y).relu()
    k = helper_linearizer_opt(r)[-1]
-    stores = [u for u in k.uops if u.op is UOps.STORE]
+    stores = [u for u in k.uops if u.op is Ops.STORE]
    # the float4 value stores directly in lds and we skip upcast
    self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4))
 
@@ -1363,49 +1363,49 @@ class TestLinearizer(unittest.TestCase):
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_skip_unmatching_upcasts(self):
    Tensor.manual_seed(0)
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
-        UOp(UOps.LOAD, dtypes.float, src=(
-          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
-          UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
+        UOp(Ops.LOAD, dtypes.float, src=(
+          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
+          UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
    opt = [
        Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
        Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)
    ]
    k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1]
-    out = [u for u in k.uops if u.op is UOps.STORE][0]
-    assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)
+    out = [u for u in k.uops if u.op is Ops.STORE][0]
+    assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype == dtypes.float.vec(4)
 
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4")
  def test_skip_unmatching_upcasts_with_gep(self):
    Tensor.manual_seed(0)
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
-        UOp(UOps.LOAD, dtypes.float, src=(
-          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
-          UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
+        UOp(Ops.LOAD, dtypes.float, src=(
+          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
+          UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
    opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
           Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
           Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
    k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1]
-    out = [u for u in k.uops if u.op is UOps.STORE][0]
-    assert out.src[-1].op is UOps.VECTORIZE and out.src[-1].dtype.count != 1
+    out = [u for u in k.uops if u.op is Ops.STORE][0]
+    assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype.count != 1
 
 @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need backends that support float4")
 class TestFloat4(unittest.TestCase):
   @staticmethod
   def count_float4(k, n=4):
-    return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.float.vec(n)]),
-            len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.float.vec(n)]))
+    return (len([uop for uop in k.uops if uop.op is Ops.LOAD and uop.dtype == dtypes.float.vec(n)]),
+            len([uop for uop in k.uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.float.vec(n)]))
   @staticmethod
   def count_half4(k):
-    return (len([uop for uop in k.uops if uop.op is UOps.LOAD and uop.dtype == dtypes.half.vec(4)]),
-            len([uop for uop in k.uops if uop.op is UOps.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
+    return (len([uop for uop in k.uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
+            len([uop for uop in k.uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
 
   # TODO: express opts below as auto opts
 
@@ -1598,19 +1598,19 @@ class TestFloat4(unittest.TestCase):
   def test_half4_load_unrolled(self):
     # from llama 7B shard 4 gpus
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
-        UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
-          UOp(UOps.CAST, dtypes.float, src=(
-            UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-              UOp(UOps.LOAD, dtypes.half, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
-                UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
-              UOp(UOps.LOAD, dtypes.half, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2),
-                UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
+        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
+          UOp(Ops.CAST, dtypes.float, src=(
+            UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+              UOp(Ops.LOAD, dtypes.half, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
+                UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
+              UOp(Ops.LOAD, dtypes.half, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2),
+                UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
 
     # TODO: fix this, expected might change but should be positive
     for expected, opts in [
@@ -1627,22 +1627,22 @@ class TestFloat4(unittest.TestCase):
   @unittest.skip("this doesn't happen anymore")
   def test_float4_acc(self):
     # from float32 stable diffusion red tinybox
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
-        UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-          UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
-            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-              UOp(UOps.LOAD, dtypes.float, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
-                UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
-              UOp(UOps.LOAD, dtypes.float, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
-                UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
-          UOp(UOps.LOAD, dtypes.float, src=(
-            UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
-            UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
+        UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+          UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+            UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+              UOp(Ops.LOAD, dtypes.float, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
+                UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
+              UOp(Ops.LOAD, dtypes.float, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
+                UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
+          UOp(Ops.LOAD, dtypes.float, src=(
+            UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
+            UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
 
     for expected, opts in [
       (1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]),
@@ -1651,22 +1651,22 @@ class TestFloat4(unittest.TestCase):
       k = Kernel(ast)
       for opt in opts: k.apply_opt(opt)
       k.linearize()
-      count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
+      count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
       assert count == expected, f"{count=}, {expected=}"
 
   @unittest.skip("this doesn't happen anymore")
   def test_float2_acc(self):
     # from resnet
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
-        UOp(UOps.CAST, dtypes.half, src=(
-          UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
-            UOp(UOps.CAST, dtypes.float, src=(
-              UOp(UOps.LOAD, dtypes.half, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
-                UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
+        UOp(Ops.CAST, dtypes.half, src=(
+          UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
+            UOp(Ops.CAST, dtypes.float, src=(
+              UOp(Ops.LOAD, dtypes.half, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
+                UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
 
     for expected, opts in [
       (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
       (4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]),
@@ -1674,7 +1674,7 @@ class TestFloat4(unittest.TestCase):
       k = Kernel(ast)
       for opt in opts: k.apply_opt(opt)
       k.linearize()
-      count = len([uop for uop in k.uops if uop.op is UOps.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
+      count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
       assert count == expected, f"{count=}, {expected=}"
 
 class TestHandCodedOpts(unittest.TestCase):
@@ -1945,23 +1945,23 @@ class TestKernelOpts(unittest.TestCase):
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
   def test_buf_index_not_found_tensor_core(self):
-    ast = UOp(UOps.SINK, src=(
-      UOp(UOps.STORE, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
-        UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
-          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-            UOp(UOps.CAST, dtypes.float, src=(
-              UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                UOp(UOps.LOAD, dtypes.int, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1),
-                  UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
-                UOp(UOps.LOAD, dtypes.int, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2),
-                  UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
-            UOp(UOps.LOAD, dtypes.float, src=(
-              UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
-              UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
+    ast = UOp(Ops.SINK, src=(
+      UOp(Ops.STORE, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
+        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
+          UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+            UOp(Ops.CAST, dtypes.float, src=(
+              UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                UOp(Ops.LOAD, dtypes.int, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1),
+                  UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
+                UOp(Ops.LOAD, dtypes.int, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2),
+                  UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
+            UOp(Ops.LOAD, dtypes.float, src=(
+              UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
+              UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
     with self.assertRaises(KernelOptError):
       k.apply_opt(Opt(OptOps.TC, 0, 1))
@@ -2135,11 +2135,11 @@ class TestKernelOpts(unittest.TestCase):
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
   def test_padto_group(self):
     Tensor.manual_seed(0)
-    g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
-    ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
-    ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
-    store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
-    sink = UOp(UOps.SINK, src=(store,))
+    g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
+    ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
+    ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
+    store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
+    sink = UOp(Ops.SINK, src=(store,))
     data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
     data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize()
     helper_linearizer_ast(sink, [data1, data2], opts=[
diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py
index a91b66e698..44688cc60a 100644
--- a/test/test_linearizer_dumb.py
+++ b/test/test_linearizer_dumb.py
@@ -5,7 +5,7 @@
 import unittest
 from test.helpers import ast_const
 from tinygrad import Device, dtypes
-from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps
 from tinygrad.helpers import getenv
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad.engine.search import Opt, OptOps
@@ -14,26 +14,26 @@ from tinygrad.codegen.kernel import Kernel
 class TestLinearizerDumb(unittest.TestCase):
   @unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
   def test_unmerged_ifs(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
-          UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-            UOp(UOps.CAST, dtypes.half, arg=None, src=(
-              UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
-                UOp(UOps.CAST, dtypes.float, arg=None, src=(
-                  UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-                    UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
-                      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
-                    UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-                      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
+          UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+            UOp(Ops.CAST, dtypes.half, arg=None, src=(
+              UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+                UOp(Ops.CAST, dtypes.float, arg=None, src=(
+                  UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+                    UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
+                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
+                    UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
+                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
             ast_const(dtypes.half, 0.9999950000374996, st_src=(
-              UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
           ast_const(dtypes.half, 0.0, st_src=(
-            UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
+            UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0)]
     k = Kernel(ast, opts=Device["METAL"].renderer)
     k.required_optimizations()
@@ -43,88 +43,88 @@ class TestLinearizerDumb(unittest.TestCase):
     Device[Device.DEFAULT].compiler.compile_cached(prg.src)
     gate_count = len([x for x in prg.src.splitlines() if "if" in x])
     assert gate_count == 1, f"must have only one gate {gate_count} != 1"
-    assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
+    assert len([u for u in k.uops if u.op is Ops.IF]) == 1, "must have a single IF"
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
   def test_max_simplify_and_cancel(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
-          UOp(UOps.CAST, dtypes.int, arg=None, src=(
-            UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-              UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
-                UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
+          UOp(Ops.CAST, dtypes.int, arg=None, src=(
+            UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+              UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
+                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
               ast_const(dtypes.bool, True, st_src=(
-                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-          UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-            UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
+          UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+            UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
               ast_const(dtypes.int, -1, st_src=(
-                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
+                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
             ast_const(dtypes.int, 1000, st_src=(
-              UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
+              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
     opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
     k.required_optimizations()
     for opt in opts: k.apply_opt(opt)
     prg = k.to_program()
     print(prg.src)
-    assert prg.uops is not None and not any(uop.op is UOps.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
+    assert prg.uops is not None and not any(uop.op is Ops.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
 
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
   def test_expander_new_srcs(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
-          UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-            UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-            UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+          UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+            UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+            UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
     opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
     k.required_optimizations()
     for opt in opts: k.apply_opt(opt)
     prg = k.to_program()
     print(prg.src)
-    if_uops = [u for u in k.uops if u.op is UOps.IF]
+    if_uops = [u for u in k.uops if u.op is Ops.IF]
     self.assertIn(len(if_uops), {1,3})
     conditions = if_uops[0].src[0].sparents
     self.assertLessEqual(len(conditions), 9)
 
   # this was a bug in embedding, someday we should fold this anyway
   def test_llama_embedding(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.CAST, dtypes.half, arg=None, src=(
-          UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
-            UOp(UOps.CAST, dtypes.float, arg=None, src=(
-              UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-                UOp(UOps.CAST, dtypes.half, arg=None, src=(
-                  UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                    UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-                      UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-                        UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.CAST, dtypes.half, arg=None, src=(
+          UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+            UOp(Ops.CAST, dtypes.float, arg=None, src=(
+              UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+                UOp(Ops.CAST, dtypes.half, arg=None, src=(
+                  UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                    UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+                      UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+                        UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
                           ast_const(dtypes.int, 1, st_src=(
-                            UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
+                            UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
                         ast_const(dtypes.int, -1, st_src=(
-                          UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-                      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-                        UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
-                        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+                          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+                      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+                        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
+                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
                     ast_const(dtypes.bool, True, st_src=(
-                      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-                UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
+                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
+                UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() print(prg.src) - load_idxs = [x.src[1] for x in k.uops if x.op is UOps.LOAD and x.src[0].arg == 3] + load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 3] assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!" @unittest.expectedFailure @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4") def test_unrolled_float4_align(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( - UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(UOps.LOAD, dtypes.long, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.LOAD, dtypes.long, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), ast_const(dtypes.long, -1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.bool, True, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, 
dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() print(prg.src) - load_idxs = [x.src[1] for x in k.uops if x.op is UOps.LOAD and x.src[0].arg == 2] + load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2] assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!" @unittest.expectedFailure @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4") @unittest.skipIf(getenv("PTX"), "this is somehow correct in PTX") def test_upcasted_stores_out_of_order(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() print(prg.src) - store_idxs = [x.src[1] for x in k.uops if x.op is UOps.STORE] + store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE] for i in range(len(store_idxs) - 1): first_bounds = store_idxs[i].vmin+store_idxs[i].vmax next_bounds = 
store_idxs[i+1].vmin+store_idxs[i+1].vmax
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index 89e4c0be51..b3754aa7ad 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -2,7 +2,7 @@
 import unittest, random
 import numpy as np
 from tinygrad.codegen.kernel import Kernel, KernelOptError
-from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps, TernaryOps
+from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps, TernaryOps
 from tinygrad.engine.search import Opt, OptOps
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.helpers import CI
@@ -40,609 +40,609 @@ class TestLinearizerFailures(unittest.TestCase):
     Tensor.manual_seed(42)
 
   def test_failure_1(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-
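# ---------------------------------------------------------------------------
# NOTE: KernelOptError (imported above) is what Kernel.apply_opt raises when an
# Opt does not fit the kernel's shape or the backend's limits. A hedged sketch
# of applying an opts list defensively, using only names this file imports:
def apply_opts(ast, opts):
  k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
  for opt in opts:
    try: k.apply_opt(opt)
    except KernelOptError: break  # stop at the first opt this backend rejects
  return k
# ---------------------------------------------------------------------------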
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_3(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)] # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed 
(32768)} helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_5(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( ast_const(dtypes.float, 0.1464405059814453, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] # EXEC_ERROR, it has no global_size helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_6(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, 
contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 10, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0)] # COMPILE FAILED, KeyError: UOps.CONST helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_7(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 
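# ---------------------------------------------------------------------------
# NOTE: the mask= field on a View marks the valid sub-rectangle of the shape;
# reads outside it act as padding (see the masked views in test_failure_2 and
# test_failure_6 above). A sketch of building one directly — View.create and
# the .contiguous property are assumptions about the API, inferred from how
# views print in these dumps:
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View  # assumed import path

v = View.create(shape=(11, 19), mask=((0, 11), (9, 19)))  # as in test_failure_6
st = ShapeTracker(views=(v,))
print(st.contiguous)  # False: part of the (11, 19) region is masked off
# ---------------------------------------------------------------------------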
56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] # test/test_linearizer_failures.py Fatal Python error: Segmentation fault helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_8(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - x9:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + x9:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), x9,)),)), ast_const(dtypes.float, 0.000244140625, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 1e-06, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, 
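# ---------------------------------------------------------------------------
# NOTE: amt=0 on UNROLL/UPCAST is read here as "the whole axis" (an assumption
# consistent with how these opts lists are written), and every full unroll
# multiplies the number of inlined terms. Back-of-envelope for test_failure_8's
# four amt=4 unrolls, which is roughly why clang's bracket limit trips:
import math
unrolls = (4, 4, 4, 4)     # the four Opt(op=OptOps.UNROLL, ..., amt=4) above
print(math.prod(unrolls))  # 256 inlined terms, about one nesting level each
# ---------------------------------------------------------------------------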
amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)] # fatal error: bracket nesting level exceeded maximum of 256 # note: use -fbracket-depth=N to increase maximum nesting level helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_9(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 4500, 0, 0, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 4500, 0, 0, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_10(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)), 
src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 1, 1024), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 1, 1024), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) helper_test_lin(Kernel(ast), [], failed_platforms=[]) def test_failure_11(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, 
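# ---------------------------------------------------------------------------
# NOTE: REDUCE_AXIS carries its reduction as arg=(op, axes), e.g.
# (BinaryOps.ADD, (3,)) in test_failure_10 above. A small walker in the same
# spirit as the earlier sketches, assuming only .op/.src/.arg:
def reduces(u):
  # yield (reduce_op, axes) for every REDUCE_AXIS in the tree
  if u.op is Ops.REDUCE_AXIS: yield u.arg
  for s in u.src: yield from reduces(s)
# list(reduces(ast)) -> [(BinaryOps.ADD, (3,))] for test_failure_10
# ---------------------------------------------------------------------------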
arg=UnaryOps.RECIP, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), 
offset=0, mask=None, contiguous=True), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), x42:=ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, 
arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(1,), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.ALU, dtypes.float, 
arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(1,), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), ast_const(dtypes.float, 5.425347222222222e-05, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), ast_const(dtypes.float, 1e-05, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)), x42,)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - 
UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)) + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)) helper_test_lin(Kernel(ast), [], failed_platforms=[]) def test_failure_12(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, 
dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @unittest.skip("found implicit expand") def test_failure_12_multireduce(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x6:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + x5:=UOp(Ops.ALU, 
dtypes.float, arg=BinaryOps.ADD, src=( + x6:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x6,)), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( x5,)),)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # both kernels are correct from a code standpoint, but generate different results due to precision errors (switching to float results in output matches) def test_failure_13(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(51864, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(0, 0, 1, 384), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=19584, mask=None, contiguous=False),)), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 
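# ---------------------------------------------------------------------------
# NOTE: the x5:= / x6:= bindings in dumps like test_failure_12_multireduce are
# Python walrus assignments: when the name recurs (x5 feeding the second
# REDUCE_AXIS above), the same UOp object has two parents, so the "tree" is
# really a DAG. A sketch that deduplicates by identity while walking:
def unique_nodes(u, seen=None):
  seen = {} if seen is None else seen
  if id(u) not in seen:  # visit shared nodes (x5, x6, ...) only once
    seen[id(u)] = u
    for s in u.src: unique_nodes(s, seen)
  return list(seen.values())
# ---------------------------------------------------------------------------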
0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(51864, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(0, 0, 1, 384), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=19584, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"]) def test_failure_14(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, 
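# ---------------------------------------------------------------------------
# NOTE: the precision comment above test_failure_13 is the usual half-float
# accumulation story: summing ~51864 products in float16 drops low bits that a
# float32 accumulator keeps. A standalone illustration in plain numpy (not
# tinygrad):
import numpy as np
rng = np.random.default_rng(0)
a = rng.random(51864, dtype=np.float32).astype(np.float16)
b = rng.random(51864, dtype=np.float32).astype(np.float16)
half_acc  = (a * b).sum(dtype=np.float16)                 # accumulate in half
float_acc = (a.astype(np.float32) * b.astype(np.float32)).sum()
print(float(half_acc), float(float_acc))                  # visibly different
# ---------------------------------------------------------------------------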
dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] # COMPILE_ERROR on METAL in fuzz_linearizer: unused variables and undeclared variables helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_15(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 0, 14, 1, 196, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 480, 0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, 
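# ---------------------------------------------------------------------------
# NOTE: PADTO (as in test_failure_14's opts above) rounds an axis up to a
# multiple of amt so later UPCAST/UNROLL divide it evenly; the padding is
# masked off. The arithmetic, with an illustrative axis length (not taken from
# the kernel above):
axis_len, amt = 13, 32
padded = -(-axis_len // amt) * amt  # ceil axis_len to a multiple of amt
print(padded, padded - axis_len)    # 32 19 -> 19 padded, masked-off elements
# ---------------------------------------------------------------------------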
arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 0, 14, 1, 196, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 480, 0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1e-05, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=16)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 115: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_16(self): - ast = UOp(UOps.SINK, dtypes.void, 
arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 0.0009765625, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=4)] # COMPILE_ERROR on METAL/GPU (probably HIP/CUDA too) in fuzz_linearizer ast 154: bracket nesting level exceeded maximum of 256 helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_17(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(0, 0, 1, 40, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(188160, 0, 0, 784, 28, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, 
mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(0, 0, 1, 40, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(188160, 0, 0, 784, 28, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=1, amt=4)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 178: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_18(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(1536, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(0, 0, 1536, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, 
arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(1536, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(0, 0, 1536, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUPTOP, axis=0, amt=256), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 239: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_19(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(0, 0, 36, 9, 0, 0, -3, -1), offset=8, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(252, 0, 0, 63, 7, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(0, 0, 36, 9, 0, 0, -3, -1), offset=8, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), 
strides=(252, 0, 0, 63, 7, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 379: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_20(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(4, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(4, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_21(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) #@unittest.skipIf(Device.DEFAULT in ("LLVM", "METAL", "CLANG"), "flaky") @unittest.skip("flaky everywhere") def 
test_failure_22(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( x4:=ast_const(dtypes.float, 0.000244140625, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - 
UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=9, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), 
offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=13, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), 
strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=15, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), 
offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 17280, 180, 18, 1), offset=19, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)),)),)), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=17, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, 
src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=9, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, 
src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 
16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=13, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), 
(0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=15, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 17280, 180, 18, 1), offset=19, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), 
View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)),)),)), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=17, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), ast_const(dtypes.float, 2.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), - x80:=UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), + x80:=UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), x4,)), ast_const(dtypes.float, 1e-05, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), x80,)),)),)),)) opts = [] helper_test_lin(Kernel(ast), opts, failed_platforms=["METAL", "CUDA"]) def test_failure_23(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( 
- UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_24(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # this is the cause of the GPT2 BEAM instability. 
bisects to PR#3530 O(n) arange attempt def test_failure_25(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # COMPARE_ERROR from GPT2 kernel - stems from uops.py self.simplify_phi_loops def test_failure_26(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(0, 0), offset=0, 
mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) all_failing_opts = [ [Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0)], [Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4)], @@ -673,14 +673,14 @@ class TestLinearizerFailures(unittest.TestCase): # y: array([0.8687, 0.996 , 0.829 , ..., 0. , 0. , 0. ], dtype=float16) # COMPARE FAILED!! def test_failure_27(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) all_failing_opts = [ [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=0)], ] @@ -688,96 +688,96 @@ class TestLinearizerFailures(unittest.TestCase): helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_28(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.bfloat16.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.bfloat16, arg=TernaryOps.WHERE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( - x5:=UOp(UOps.CAST, dtypes.bfloat16, arg=None, src=( - UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.bfloat16.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.bfloat16, arg=TernaryOps.WHERE, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + x5:=UOp(Ops.CAST, dtypes.bfloat16, arg=None, src=( + UOp(Ops.LOAD, dtypes.int, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, 
dtypes.int.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), x9:=ast_const(dtypes.bfloat16, 230.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( x5, ast_const(dtypes.bfloat16, 0.004347826086956522, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.bfloat16, 0.199374800625, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.bfloat16, 1.99375e-07, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.bfloat16, arg=BinaryOps.ADD, src=( x5, x9,)), ast_const(dtypes.bfloat16, 0.0012987012987012987, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.bfloat16, -0.19439062499999998, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.bfloat16, 0.199375, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) helper_test_lin(Kernel(ast), opts=[], failed_platforms=[]) def test_failure_29(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - 
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.half, arg=None, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 128, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 128), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 128, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 128), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=1), Opt(op=OptOps.PADTO, axis=2, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0) def test_failure_30(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.half, arg=None, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, 
arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(3072, 0, 0, 32, 1, 1024, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(0, 0, 12, 0, 0, 4, 2, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(3072, 0, 0, 32, 1, 1024, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(0, 0, 12, 0, 0, 4, 2, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=3, amt=32), Opt(op=OptOps.LOCAL, axis=3, amt=32), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # from METAL=1 fuzz_linearizer command in test.yml def test_failure_31(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, 
arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.4426950408889634, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -785,81 +785,81 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_32(self): # kernel from beaming resnet # Memory access fault on tinybox red - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.half, arg=None, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 256, 4, 16, 4, 16), strides=(0, 50176, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 256), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(1048576, 0, 0, 64, 1, 4096, 1088, 17), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 256, 4, 16, 4, 16), strides=(0, 50176, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 256), (0, 4), (1, 15), (0, 4), (1, 15)), 
contiguous=False), View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(1048576, 0, 0, 64, 1, 4096, 1088, 17), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05) def test_failure_33(self): # UOps.UNMUL left after linearize - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - x5:=UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( x5, x10:=ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( ast_const(dtypes.float, 0.06788442333021306, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), 
strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), x5,)), ast_const(dtypes.float, -0.03394221166510653, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=-26040, mask=((26040, 32640),), contiguous=False),)), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=-26040, mask=((26040, 32640),), contiguous=False),)), src=()),)), ast_const(dtypes.float, -0.18257418583505536, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((26040, 32640),), contiguous=False),)), src=()),)),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((26040, 32640),), contiguous=False),)), src=()),)),)),)), x10,)), ast_const(dtypes.float, -1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), x10,)),)),)),)),)) opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # from fuzzing on metal def test_failure_34(self, unroll=False): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(77, 0, 0, 7, 1, 0, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + 
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(77, 0, 0, 7, 1, 0, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] if unroll else [Opt(op=OptOps.TC, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -868,18 +868,18 @@ class TestLinearizerFailures(unittest.TestCase): # from world fuzz_linearizer: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_N=100 FUZZ_NTH=84 python3 ./test/external/fuzz_linearizer.py def test_failure_36(self): # UOps.UNMUL left after linearize - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.uchar, arg=None, src=( - UOp(UOps.ALU, dtypes.uint, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=( - UOp(UOps.CAST, dtypes.uint, arg=None, src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.uchar, arg=None, src=( + UOp(Ops.ALU, dtypes.uint, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.CAST, dtypes.uint, arg=None, src=( ast_const(dtypes.uchar, 1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)), ast_const(dtypes.uint, -1, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, 
mask=None, contiguous=False),)), src=()),)),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -888,26 +888,26 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_37(self): # beautiful mnist kernel number 28: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=28 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + 
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -915,19 +915,19 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_38(self): # beautiful mnist kernel number 87: 6 possible TC axis_choices (2 for axis_buf1 and 3 reduce) and first/second reduce axis fail for both axis_buf1 choices # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=87 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(784, 0, 0, 28, 1, 0, 28, 1, 1568), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(784, 0, 0, 28, 1, 0, 28, 1, 1568), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) for axis in [0,1,3,4]: opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] 
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -936,26 +936,26 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_39(self): # beautiful mnist kernel number 127: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=127 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, 
dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
+            UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     for axis in [0,1,2,3,4,5]:
       opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
       helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
@@ -963,16 +963,16 @@ class TestLinearizerFailures(unittest.TestCase):
   def test_failure_40(self):
     # beautiful mnist kernel number 3:
     # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 DEBUG=2 FUZZ_NTH=3 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-          UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+          UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
            ast_const(dtypes.int, 1, st_src=(
-              UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
+              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
           ast_const(dtypes.int, -1, st_src=(
-            UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
+            UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     for amt in [16,32]:
       opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=amt), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
       helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
@@ -981,20 +981,20 @@ class TestLinearizerFailures(unittest.TestCase):
   @unittest.skipIf(CI, "for real AMD GPU")
   def test_failure_41(self):
     # One more resnet crash with a page fault on AMD. Checked on rocm6.1.3, -O1 works, -O2 fails
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
-        UOp(UOps.CAST, dtypes.half, arg=None, src=(
-          UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
-            UOp(UOps.CAST, dtypes.float, arg=None, src=(
-              UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-                UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 128, 4, 58, 4, 58), strides=(0, 401408, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 128), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(6889472, 0, 0, 464, 2, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)),
-                UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                  UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(Ops.CAST, dtypes.half, arg=None, src=(
+          UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
+            UOp(Ops.CAST, dtypes.float, arg=None, src=(
+              UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+                UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 128, 4, 58, 4, 58), strides=(0, 401408, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 128), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(6889472, 0, 0, 464, 2, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)),
+                UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                  UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
     opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"])
@@ -1002,84 +1002,84 @@ class TestLinearizerFailures(unittest.TestCase):
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test needs shared")
   def test_failure_42(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test needs shared") def test_failure_43(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test needs shared") def test_failure_44(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - 
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) assert k is not None - ifs = [u for u in k.uops if u.op is UOps.IF] + ifs = [u for u in k.uops if u.op is Ops.IF] self.assertEqual(len(ifs), 4) #for st in k.uops.sink.src: self.assertEqual(len(st.src), 4) self.assertLessEqual(len(ifs[0].src[0].sparents), 17) def test_failure_45(self): - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 3, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, 
       (2, 3)), src=(
+      UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 3, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.CAST, dtypes.float, arg=None, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+      UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
       ast_const(dtypes.int, 1, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)),
       x19:=ast_const(dtypes.int, -1, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
       x21:=ast_const(dtypes.bool, True, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-      UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+      UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=(
       ast_const(dtypes.int, 1, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)),
       x19,)),)),
       x21,)),)),)),)),)),)),))
     # ValueError: size mismatched, can't reshape self.shape=(6, 2, 3, 3) -> new_shape=(6, 2, 3, 1, 2)
@@ -1087,153 +1087,153 @@ class TestLinearizerFailures(unittest.TestCase):
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])

   def test_failure_46(self):
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-      UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
-      UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-      UOp(UOps.CAST, dtypes.float, arg=None, src=(
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+      UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
+      UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+      UOp(Ops.CAST, dtypes.float, arg=None, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
       ast_const(dtypes.bool, True, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-      UOp(UOps.LOAD, dtypes.bool, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=3, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-      UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
-      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+      UOp(Ops.LOAD, dtypes.bool, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=3, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
+      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
+      UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
+      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
     opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2)]
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])

   def test_failure_47(self):
     # upcast an arange, failed with UOP_IS_SYMBOLIC=1 (fixed!)
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-      UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+      UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
       ast_const(dtypes.int, 1, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)),
       ast_const(dtypes.int, -1, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     opts = [Opt(op=OptOps.UPCAST, axis=0, amt=3)]
     helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])

   @unittest.skipUnless(not CI and Device.DEFAULT in ("NV", "CUDA"), "for real NV")
   def test_failure_48(self):
     # with UOP_IS_SYMBOLIC=1, generates the wrong IDIV (fixed!)
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=(
-      UOp(UOps.CAST, dtypes.float, arg=None, src=(
-      UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-      UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 0, 56, 1, 3136, 0, 0, 802816), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=(
+      UOp(Ops.CAST, dtypes.float, arg=None, src=(
+      UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+      UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 0, 56, 1, 3136, 0, 0, 802816), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
     opts = [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])

   def test_failure_49(self):
     # with UOP_IS_SYMBOLIC=1, on METAL it breaks store fusion and has A+B and B+A being two different UOp
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
-      UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(10, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+      UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(10, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
     opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])

   def test_failure_50(self):
     # from BEAM_COMPARE=2 running tinyphysics.onnx model
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=(
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
-      UOp(UOps.LOAD, dtypes.bool, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
-      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.MUL, src=(
+      UOp(Ops.LOAD, dtypes.bool, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
+      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
       ast_const(dtypes.bool, True, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
       ast_const(dtypes.bool, True, st_src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
     opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])

   def test_failure_51(self):
     # regression test for #7019, training bert on tinybox red
-    ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.ALU, dtypes.half, arg=UnaryOps.RECIP, src=(
-      UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
-      UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
-      x6:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.CONST, dtypes.half, arg=1.0, src=()),
-      x9:=UOp(UOps.CONST, dtypes.half, arg=0.0, src=()),)),
-      UOp(UOps.ALU, dtypes.half, arg=UnaryOps.EXP2, src=(
-      UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-      UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-      UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
+    ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.ALU, dtypes.half, arg=UnaryOps.RECIP, src=(
+      UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
+      UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
+      x6:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.CONST, dtypes.half, arg=1.0, src=()),
+      x9:=UOp(Ops.CONST, dtypes.half, arg=0.0, src=()),)),
+      UOp(Ops.ALU, dtypes.half, arg=UnaryOps.EXP2, src=(
+      UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+      UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+      UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
       x6,
-      UOp(UOps.CONST, dtypes.half, arg=2.0, src=()),
+      UOp(Ops.CONST, dtypes.half, arg=2.0, src=()),
       x9,)),
-      UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
-      UOp(UOps.CAST, dtypes.half, arg=None, src=(
-      UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
-      UOp(UOps.CAST, dtypes.float, arg=None, src=(
-      UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-      UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
-      UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
-      UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-      UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
-      UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
-      UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
+      UOp(Ops.ALU, dtypes.half, arg=BinaryOps.ADD, src=(
+      UOp(Ops.CAST, dtypes.half, arg=None, src=(
+      UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
+      UOp(Ops.CAST, dtypes.float, arg=None, src=(
+      UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+      UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
+      UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
+      UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+      UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()),
+      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
+      UOp(Ops.ALU, dtypes.half, arg=TernaryOps.WHERE, src=(
       x6,
-      UOp(UOps.CONST, dtypes.half, arg=-1.4426950408889634, src=()),
+      UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()),
       x9,)),)),)),)),)),)),))
     opts = [Opt(op=OptOps.TC, axis=0, amt=2)]
     helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[])
@@ -1243,56 +1243,56 @@ class TestLinearizerFailures(unittest.TestCase):
     # resnet beam.
    # NV also fails with a pf.
# CUDA Error 700, an illegal memory access was encountered - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.half, arg=None, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 256), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 256), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_53(self): # COMPILE_ERROR, val scope issue - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=( - UOp(UOps.ALU, dtypes.uchar, 
arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.CAST, dtypes.uchar, arg=None, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( - UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( - UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( - UOp(UOps.VALID, dtypes.bool, arg=None, src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.CONST, dtypes.int, arg=1, src=()), - x20:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)),)), - UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( - x22:=UOp(UOps.VALID, dtypes.bool, arg=None, src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.CONST, dtypes.int, arg=-1, src=()), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.ALU, dtypes.uchar, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.CAST, dtypes.uchar, arg=None, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( + UOp(Ops.LOAD, dtypes.int, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( + UOp(Ops.VALID, dtypes.bool, arg=None, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.CONST, dtypes.int, arg=1, src=()), + x20:=UOp(Ops.CONST, dtypes.int, arg=0, src=()),)),)), + UOp(Ops.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( + x22:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(0, 0, 0), offset=0, mask=None, 
contiguous=False),)), src=()),)), + UOp(Ops.CONST, dtypes.int, arg=-1, src=()), x20,)),)),)), - UOp(UOps.ALU, dtypes.bool, arg=TernaryOps.WHERE, src=( + UOp(Ops.ALU, dtypes.bool, arg=TernaryOps.WHERE, src=( x22, - UOp(UOps.CONST, dtypes.bool, arg=True, src=()), - UOp(UOps.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),)) + UOp(Ops.CONST, dtypes.bool, arg=True, src=()), + UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.GROUPTOP, axis=1, amt=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["AMD", "GPU", "METAL", "NV", "CUDA"]) @@ -1300,41 +1300,41 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_54(self): # resnet beam # HIP: Memory access fault by GPU node-1 (Agent handle: 0x56c21f1d1480) on address 0x730cc242e000. Reason: Page not present or supervisor privilege. - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.half, arg=None, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), 
strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") def test_failure_55(self): W = 2 - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.CAST, dtypes.half, arg=None, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( - UOp(UOps.CAST, dtypes.float, arg=None, src=( - UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, W, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, W), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.CAST, dtypes.half, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.CAST, dtypes.float, arg=None, src=( + UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, W, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, W), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(W, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) + UOp(Ops.LOAD, dtypes.half, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index 37f0f30430..5c152fec11 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps from tinygrad.engine.search import time_linearizer, bufs_from_lin # stuff needed to unpack a kernel -from 
tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps +from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -24,143 +24,143 @@ def _test_overflow(ast, opts): @unittest.skip("unneeded without launch bounds") class TestLinearizerOverflow(unittest.TestCase): def test_overflow_1(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), x16:=ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 
112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( x23:=ast_const(dtypes.float, 1.0, st_src=( - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), - UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( x23, ast_const(dtypes.float, 1e-05, st_src=( - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x16,)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)] _test_overflow(ast, opts) # From BEAM on hlb_cifar.py def test_overflow_2(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, 
dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] _test_overflow(ast, opts) # from BEAM on default simple_conv.py (which is quite large): def test_overflow_3(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, 
contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] _test_overflow(ast, opts) # from BEAM on BS=4 simple_conv.py: def test_overflow_4(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 
130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) # from BEAM on BS=2 simple_conv.py: def test_overflow_5(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, 
axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] _test_overflow(ast, opts) # from BEAM on BS=3 simple_conv.py: def test_overflow_6(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] _test_overflow(ast, opts) # from BEAM on BS=3 simple_conv.py: (alt) def test_overflow_7(self): - ast = UOp(UOps.SINK, None, arg=None, src=( - UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( - UOp(UOps.ALU, 
dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) + ast = UOp(Ops.SINK, None, arg=None, src=( + UOp(Ops.STORE, None, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) @@ -169,26 +169,26 @@ class TestLinearizerOverflow(unittest.TestCase): class TestLinearizerOverflowAlt(unittest.TestCase): def test_overflow_1(self): BS = 2 - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop() in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop() - prod = UOp(UOps.LOAD, dtypes.float, (g1, in_st_1)) * UOp(UOps.LOAD, dtypes.float, (g2, in_st_2)) - store = UOp(UOps.STORE, src=(g0, ot_st, 
UOp(UOps.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) - ast = UOp(UOps.SINK, src=(store,)) + prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) + store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) + ast = UOp(Ops.SINK, src=(store,)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] _test_overflow(ast, opts) def test_overflow_2(self): BS = 2 - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop() in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop() - prod = UOp(UOps.LOAD, dtypes.float, (g1, in_st_1)) * UOp(UOps.LOAD, dtypes.float, (g2, in_st_2)) - store = UOp(UOps.STORE, src=(g0, ot_st, UOp(UOps.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) - ast = UOp(UOps.SINK, src=(store,)) + prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) + store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) + ast = UOp(Ops.SINK, src=(store,)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)] _test_overflow(ast, opts) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 54233b0f2c..0cca20d267 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,7 +1,7 @@ import unittest, functools, random from typing import List from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes -from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, UOps +from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.schedule import create_schedule @@ -617,24 +617,24 @@ class TestMultiTensor(unittest.TestCase): t = t + 1 for si in t.schedule(): ast = si.ast.src[0] - assert ast.op is UOps.STORE + assert ast.op is Ops.STORE assert ast.src[2].arg is BinaryOps.ADD - assert ast.src[2].src[0].op is UOps.LOAD - assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1 + assert ast.src[2].src[0].op is Ops.LOAD + assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1 t = 2 * t for si in t.schedule(): ast = si.ast.src[0] - assert ast.op is UOps.STORE + assert ast.op is Ops.STORE assert ast.src[2].arg is BinaryOps.MUL - assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2 - assert ast.src[2].src[1].op is UOps.LOAD + assert 
+      assert ast.src[2].src[1].op is Ops.LOAD
     t = t + t.full_like(3)
     for si in t.schedule():
       ast = si.ast.src[0]
-      assert ast.op is UOps.STORE
+      assert ast.op is Ops.STORE
       assert ast.src[2].arg is BinaryOps.ADD
-      assert ast.src[2].src[0].op is UOps.LOAD
-      assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
+      assert ast.src[2].src[0].op is Ops.LOAD
+      assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
 
   def test_shard_memory(self):
     devices = (d0, d1, d2, d3)
diff --git a/test/test_nn.py b/test/test_nn.py
index 4b2d83dd68..e5c241c544 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch
 from tinygrad import Tensor, Device, TinyJit
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.helpers import CI, Context
 from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
 from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
@@ -517,7 +517,7 @@ class TestNN(unittest.TestCase):
                 [12, 19, 8, 1]])
     result = layer(a)
     schedule = create_schedule([result.lazydata])
-    self.assertEqual(3, len([item for item in schedule if item.ast.op is UOps.SINK]), "first run realizes arange, weight, and embedding")
+    self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
     run_schedule(schedule)
 
     b = Tensor([[1, 2, 3],
@@ -525,7 +525,7 @@
                 [4, 5, 6],
                 [7, 8, 9]])
     result = layer(b)
     schedule = create_schedule([result.lazydata])
-    self.assertEqual(1, len([item for item in schedule if item.ast.op is UOps.SINK]), "second run realizes embedding only")
+    self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
     run_schedule(schedule)
 
   def test_embedding_shape(self):
diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index d17966295b..77975d6665 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import dedup, flatten, prod
 from tinygrad.renderer.cstyle import CStyleLanguage
-from tinygrad.ops import BinaryOps, UOp, UOps
+from tinygrad.ops import BinaryOps, UOp, Ops
 from tinygrad.renderer import Program
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.lazy import LazyBuffer
@@ -20,7 +20,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
   def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
   uops = dedup(flatten(_recursive_add(st) for st in stores))
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
-             initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
+             initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
   inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render("test", uops)
   ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
@@ -30,13 +30,13 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
 @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
 class TestCStyleFailures(unittest.TestCase):
   def test_inline_const_alu(self):
-    a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
+    a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
     idx = UOp.const(dtypes.int, 0)
-    ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
+    ld = UOp(Ops.LOAD, dtypes.int, (b, idx))
     alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
     store = UOp.store(a, idx, alu)
-    sink = UOp(UOps.SINK, dtypes.void, (store,))
+    sink = UOp(Ops.SINK, dtypes.void, (store,))
     uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
     # CLANG doesn't use the max function
     ret = _test_uop_result([Tensor([1])], uops)[0]
@@ -45,21 +45,21 @@ class TestCStyleFailures(unittest.TestCase):
 @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
 class TestPTXFailures(unittest.TestCase):
   def test_gated_store_with_alu(self):
-    a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
-    gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
-    sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
+    a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
+    gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
+    sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
     uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
     ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
 
   def test_gated_store_with_if(self):
-    a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
+    a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
     val = UOp.const(dtypes.int, 1)
-    if_uop = UOp(UOps.IF, dtypes.void, (gate_alu,))
-    gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, val, if_uop))
-    sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
+    if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
+    gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, val, if_uop))
+    sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
     uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
     ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 21916bdf52..0c39d13a10 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
 from tinygrad.dtype import DType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
 from tinygrad.codegen.kernel import Kernel, verify_ast
 from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
@@ -29,7 +29,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
   if to_prerealize:
     for pre in to_prerealize: pre.schedule()
   sched = create_schedule(outs)
-  if filter_sink: sched = [s for s in sched if s.ast.op is UOps.SINK]
+  if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     for i, s in enumerate(sched):
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
   if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
   # test the (sink) ops linearize
   for s in sched:
-    if s.ast.op is not UOps.SINK: continue
+    if s.ast.op is not Ops.SINK: continue
     l = Kernel(s.ast)
     l.hand_coded_optimizations()
     l.linearize()
@@ -58,7 +58,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   dtypes.default_float = old_default_float
   with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
   run_schedule(s.copy())
-  cnt = len([si for si in s if si.ast.op is UOps.SINK])
+  cnt = len([si for si in s if si.ast.op is Ops.SINK])
   assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
   if getenv("CHECK", 1):
     import torch
@@ -191,7 +191,7 @@ class TestSchedule(unittest.TestCase):
     r1 = (x - r0).sum(axis=0).div(2)
     out = r0 + r1
     schedule = check_schedule(out, 2)
-    reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
+    reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
     assert len(reduceops) == 2
 
   def test_cache_reduce_multiple_children(self):
@@ -202,7 +202,7 @@
     out0 = r0 + y
     out1 = r1 + y
     schedule = check_schedule([out0, out1], 4)
-    reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
+    reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
     assert len(reduceops) == 2
 
   def test_fold_double_unary(self):
@@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.empty(16, 16)
     b = (a.sum(0) + a.max(1)) + 2
     schedule = check_schedule(b, 2)
-    self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
 
   # multireduce spec
   def test_multireduce_midreduce_nochase(self):
@@ -1117,7 +1117,7 @@ class TestSchedule(unittest.TestCase):
     b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
     # schedule = check_schedule(b, 2)
     schedule = check_schedule(b, 4)
-    self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
+    self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
     run_schedule(schedule)
     np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
@@ -1352,7 +1352,7 @@ class TestIndexing(unittest.TestCase):
     with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
       lst = [xt] if isinstance(xt, Tensor) else xt
       s = Tensor.schedule(*lst)
-      kernels = [si for si in s if si.ast.op is UOps.SINK]
+      kernels = [si for si in s if si.ast.op is Ops.SINK]
       for si in kernels: verify_ast(si.ast)
       run_schedule(s)
       if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
@@ -1607,20 +1607,20 @@
     self.assertLess(et, 1200)
 
   def test_no_rewrite_elementwise(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
-    ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
-    sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
+    ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
+    ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
+    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
     rsink = graph_rewrite(sink, view_right)
     self.assertEqual(rsink.key, sink.key)
 
   def test_simple_store_reshape(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
-    r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
+    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
+    r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
     r = r + ast_const(dtypes.int, 2, ())
-    sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
+    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
     rsink = graph_rewrite(sink, view_right)
     # NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
     # with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
@@ -1628,21 +1628,21 @@ class TestIndexing(unittest.TestCase):
     verify_ast(rsink)
 
   def test_no_reshape_reduceop(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
-    sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
+    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
+    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
     rsink = graph_rewrite(sink, view_right)
     verify_ast(sink)
     self.assertEqual(sink.key, rsink.key)
 
   def test_reshape_many(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
-    r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
+    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
+    r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
     for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
-    sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
+    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
     rsink, et = timeit(graph_rewrite, sink, view_right)
     # NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
     # with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
@@ -1656,11 +1656,11 @@ class TestIndexing(unittest.TestCase):
     sizes = [10*(i+1) for i in range(SZ)]
     tms: List[float] = []
     for sz in sizes:
-      bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-      ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-      r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
+      bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
+      ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
+      r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
      for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
-      sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
+      sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
       rsink, et = timeit(graph_rewrite, sink, view_right)
       with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
       verify_ast(rsink)
@@ -1676,20 +1676,20 @@ class TestIndexing(unittest.TestCase):
 
   def test_swizzle_rewrite(self):
     # graph rewrite
-    sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
-        UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
-          UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-            UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
-              UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
-                UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-                  x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
-            UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+    sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
+        UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+          UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+            UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
+              UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+                UOp(Ops.LOAD, dtypes.int, arg=None, src=(
+                  x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
+            UOp(Ops.LOAD, dtypes.int, arg=None, src=(
               x8,
-              UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
+              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
     sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
     # verify output
     k = Kernel(sink)
@@ -1705,13 +1705,13 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.randint(4,).realize()
     expected_out = a.numpy().sum(0)+1
     # LazyBuffer to pre-rewrite AST
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
-    r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
-    swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
+    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
+    swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
     const = ast_const(dtypes.int, 1, ())
     alu = swizzle_r+const
-    sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
+    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
     # graph rewrite
     sink = graph_rewrite(sink, view_right)
     # verify output
@@ -1728,13 +1728,13 @@ class TestIndexing(unittest.TestCase):
     b = Tensor.randint(4,).realize()
     expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
     # LazyBuffer to pre-rewrite AST
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
-    ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
-    r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
-    ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
-    r2 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
-    alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
-    sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
+    ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
+    r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
+    ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
+    r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
+    alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
+    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
     # graph rewrite
     sink = graph_rewrite(sink, view_right)
     # verify output
@@ -1745,51 +1745,51 @@ class TestIndexing(unittest.TestCase):
     np.testing.assert_equal(c.numpy(), expected_out)
 
   def test_swizzle_rewrite_alt(self):
-    swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
-      UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
-        UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-          UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
-          UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
+    swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
+      UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
+        UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+          UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
+          UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
     # there's an EXPAND pushing through the REDUCE_AXIS
     self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
     ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
     # EXPAND is rewritten
     self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
     # and pushed to the LOAD
-    new_load_st = unwrap([x for x in ret.parents if x.op is UOps.VIEW][0].st)
+    new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
     self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
     self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
 
   def test_permute_rewrite(self):
-    sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
-      x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
-      x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
-      UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
+    sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
+      x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
+      x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
+      UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
         x1,
-        UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
-          UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-            UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
-              UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
-                UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
-                  UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
-                    UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                      x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                        UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
+        UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
+          UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+            UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
+              UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
+                UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
+                  UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
+                    UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
+                      x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                        UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
                         x2,)),)),
-                    UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                      UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                        UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
-                        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
-              UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
-                UOp(UOps.LOAD, dtypes.float, arg=None, src=(
-                  UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
-                  UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
-            UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
+                    UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
+                      UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                        UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
+                        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
+              UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
+                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
+                  UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
+                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
+            UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
              x11,)),)),)),)),))
     @track_rewrites()
     def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
     ret = rewrite(sink)
-    assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
!= 0]) == 0, f"unmerged views left in sink {ret}" + assert len([x for x in ret.sparents if x.op is Ops.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}" if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_search.py b/test/test_search.py index 599aea748a..0e0d1ce3fa 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -3,7 +3,7 @@ import unittest from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.kernel import Kernel -from tinygrad.ops import UOp, UOps, BinaryOps +from tinygrad.ops import UOp, Ops, BinaryOps from tinygrad.engine.schedule import create_schedule from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer @@ -16,15 +16,15 @@ from tinygrad.shape.view import View class TestTimeLinearizer(unittest.TestCase): def test_reasonable_time(self): - si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0] + si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0] out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate() - memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is UOps.LOAD} + memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is Ops.LOAD} rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))] tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True) assert tm > 0 and tm != float('inf') def test_bufs_from_lin(self): - si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0] + si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0] rawbufs = bufs_from_lin(lin:=Kernel(si.ast)) assert len(rawbufs) == len(lin.membufs) == 2 assert all(r is not None for r in rawbufs) @@ -34,7 +34,7 @@ class TestTimeLinearizer(unittest.TestCase): def test_bufs_from_lin_alt(self): a = Tensor.randn(4, 4).realize() b = a+a[0] - si = [si for si in b.schedule() if si.ast.op is UOps.SINK][0] + si = [si for si in b.schedule() if si.ast.op is Ops.SINK][0] rawbufs = bufs_from_lin(k:=Kernel(si.ast)) assert len(rawbufs) == len(k.membufs) == 2 assert all(r is not None for r in rawbufs) @@ -46,12 +46,12 @@ class TestTimeLinearizer(unittest.TestCase): Ensure that the kernel count is not incremented by time_linearizer when clearing l2 """ # ast of Tensor.zeros(16).contiguous().realize() - ast = UOp(UOps.SINK, src=( - UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))), + ast = UOp(Ops.SINK, src=( + UOp(Ops.STORE, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))), ast_const(dtypes.float, 0.0, st_src=( - UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),)) + UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),)) lin = Kernel(ast) bufs = bufs_from_lin(lin) @@ -103,37 +103,37 @@ class TestBEAM(unittest.TestCase): def test_filter_global_buffer(self): # taken from 
https://github.com/tinygrad/tinygrad/issues/4612 - ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( - UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 - UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501 - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 - UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + 
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 ast_const(dtypes.float, 1.4285714285714286, st_src=( - UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 lin = Kernel(ast) bufs = bufs_from_lin(lin) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 0e1b235e8c..8d56eb2d98 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,7 +2,7 @@ from typing import List import unittest, time from tinygrad import dtypes, Device from tinygrad.helpers import DEBUG -from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo +from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo from tinygrad.ops import UPat, PatternMatcher from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index @@ -24,30 +24,30 @@ class 
     c1 = UOp.const(dtypes.int, 1)
     c2 = UOp.const(dtypes.int, 2)
     st = time.perf_counter()
-    uops = [UOp(UOps.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
+    uops = [UOp(Ops.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
     et = time.perf_counter() - st
     print(f"created {len(uops)} uops in {et*1000:.2f} ms")
 
   def test_expand_rewrite(self):
-    sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
-      UOp(UOps.STORE, dtypes.void, arg=None, src=(
-        UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
-        UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
+    sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
+      UOp(Ops.STORE, dtypes.void, arg=None, src=(
+        UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
+        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
           strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0), offset=0, mask=None, contiguous=False),)), src=()),
-        UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
-          UOp(UOps.CAST, dtypes.float, arg=None, src=(
-            UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
-              UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
-                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
+        UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
+          UOp(Ops.CAST, dtypes.float, arg=None, src=(
+            UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
+              UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
+                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
                   View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
                     mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
                   View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1),
                     offset=0, mask=None, contiguous=False))), src=()),)),
-              UOp(UOps.LOAD, dtypes.half, arg=None, src=(
-                UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
-                UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
+              UOp(Ops.LOAD, dtypes.half, arg=None, src=(
+                UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
+                UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
                   View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1),
                     offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
     lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
@@ -82,7 +82,7 @@ class TestGraphRewriteConst(unittest.TestCase):
     v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
     v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
     ret = graph_rewrite(v1+v2, sym)
-    self.assertEqual(ret.op, UOps.VCONST)
+    self.assertEqual(ret.op, Ops.VCONST)
     self.assertEqual(ret.dtype, dtypes.int.vec(3))
     self.assertEqual(ret.arg, (5,7,9))
@@ -90,31 +90,31 @@ class TestGraphRewriteConst(unittest.TestCase):
     v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
     v2 = UOp.const(dtypes.int.vec(3), (2,1,0))
     ret = graph_rewrite(v1+v2, sym)
-    self.assertEqual(ret.op, UOps.CONST)
+    self.assertEqual(ret.op, Ops.CONST)
     self.assertEqual(ret.dtype, dtypes.int.vec(3))
     self.assertEqual(ret.arg, 2)
 
 class TestGraphRewrite(unittest.TestCase):
   def test_dedup(self):
-    v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
-    v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
+    v1 = UOp(Ops.DEFINE_VAR, dtypes.float)
+    v2 = UOp(Ops.DEFINE_VAR, dtypes.float)
     nout = graph_rewrite(v1+v2, PatternMatcher([]))
     self.assertIs(nout.src[0], nout.src[1])
 
   # NOTE: this shows why we can't have a UOp in arg
   @unittest.expectedFailure
   def test_no_dedup_args(self):
-    a1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
-    a2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
+    a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
+    a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
     sink = a1.sink(a2)
-    define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is UOps.DEFINE_VAR]
+    define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is Ops.DEFINE_VAR]
     self.assertEqual(len(define_vars), 1)
 
   def test_simple(self):
     c1 = UOp.const(dtypes.float, 1.0)
     c2 = UOp.const(dtypes.float, 2.0)
     nout = graph_rewrite(c1+c2, simple_pm)
-    self.assertEqual(nout.op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.CONST)
     self.assertEqual(nout.arg, 3.0)
 
   def test_depth_2_late(self):
@@ -122,7 +122,7 @@ class TestGraphRewrite(unittest.TestCase):
     c2 = UOp.const(dtypes.float, 2.0)
     c3 = UOp.const(dtypes.float, 3.0)
     nout = graph_rewrite(c1*c2*(c3+c3), simple_pm)
-    self.assertEqual(nout.op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.CONST)
     self.assertEqual(nout.arg, 12.0)
 
   def test_double(self):
@@ -130,7 +130,7 @@ class TestGraphRewrite(unittest.TestCase):
     c2 = UOp.const(dtypes.float, 2.0)
     c3 = UOp.const(dtypes.float, 3.0)
     nout = graph_rewrite(c1+c2+c3, simple_pm)
-    self.assertEqual(nout.op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.CONST)
     self.assertEqual(nout.arg, 6.0)
 
   def test_triple(self):
@@ -139,7 +139,7 @@ class TestGraphRewrite(unittest.TestCase):
     c3 = UOp.const(dtypes.float, 3.0)
     c4 = UOp.const(dtypes.float, 4.0)
     nout = graph_rewrite(c1+c2+c3+c4, simple_pm)
-    self.assertEqual(nout.op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.CONST)
     self.assertEqual(nout.arg, 10.0)
 
   def test_diamond(self):
@@ -147,23 +147,23 @@ class TestGraphRewrite(unittest.TestCase):
     c2 = UOp.const(dtypes.float, 2.0)
     c3 = UOp.const(dtypes.float, 3.0)
     nout = graph_rewrite((c1+c2)+(c1+c3), simple_pm)
-    self.assertEqual(nout.op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.CONST)
     self.assertEqual(nout.arg, 7.0)
 
   def test_magic_4(self):
     c1 = UOp.const(dtypes.int, 4.0)
     nout = graph_rewrite(c1, simple_pm)
-    self.assertEqual(nout.op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.CONST)
     self.assertEqual(nout.arg, 3.0)
 
   def test_depth_2_fold(self):
-    v = UOp(UOps.DEFINE_VAR, dtypes.float)
+    v = UOp(Ops.DEFINE_VAR, dtypes.float)
     c1 = UOp.const(dtypes.float, 1.0)
     c2 = UOp.const(dtypes.float, 2.0)
     nout = graph_rewrite(v+c1+c2, simple_pm)
-    self.assertEqual(nout.op, UOps.ALU)
-    self.assertEqual(nout.src[0].op, UOps.DEFINE_VAR)
-    self.assertEqual(nout.src[1].op, UOps.CONST)
+    self.assertEqual(nout.op, Ops.ALU)
+    self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
+    self.assertEqual(nout.src[1].op, Ops.CONST)
     self.assertEqual(nout.src[1].arg, 3.0)
 
   def test_commutative_work(self):
@@ -182,77 +182,77 @@ class TestGraphRewrite(unittest.TestCase):
     b = UOp.variable('b', 0, 1)
     c = UOp.variable('c', 0, 1)
     d = UOp.variable('d', 0, 1)
-    outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
+    outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
     for out in outs:
       sink = graph_rewrite(out, sym)
       print(sink.render())
-      self.assertEqual(sink.op, UOps.ALU)
-      self.assertEqual(sink.src[1].op, UOps.CONST)
-      self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
+      self.assertEqual(sink.op, Ops.ALU)
+      self.assertEqual(sink.src[1].op, Ops.CONST)
+      self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
 
 class TestUOpGraph(unittest.TestCase):
   def test_add_constant_fold(self):
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
-    out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
+    out = UOp(Ops.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 1)
     out = uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
+    self.assertEqual(out.op, Ops.CONST)
    self.assertEqual(out.arg, 3.0)
 
   def test_where_same_fold(self):
     v = UOp.variable('tmp', 0, 1)
-    c0 = UOp(UOps.CONST, dtypes.int, arg=0)
-    vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
+    c0 = UOp(Ops.CONST, dtypes.int, arg=0)
+    vc = UOp(Ops.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    out = UOp(Ops.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 1)
     out = uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
+    self.assertEqual(out.op, Ops.CONST)
     self.assertEqual(out.arg, 1.0)
 
   def test_where_const_fold(self):
-    bf = UOp(UOps.CONST, dtypes.bool, arg=False)
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
-    out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
+    bf = UOp(Ops.CONST, dtypes.bool, arg=False)
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
+    out = UOp(Ops.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 1)
     out = uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
+    self.assertEqual(out.op, Ops.CONST)
     self.assertEqual(out.arg, 2.0)
 
   def test_const_cast(self):
-    bf = UOp(UOps.CONST, dtypes.bool, arg=False)
-    out = UOp(UOps.CAST, dtypes.int, (bf,))
+    bf = UOp(Ops.CONST, dtypes.bool, arg=False)
+    out = UOp(Ops.CAST, dtypes.int, (bf,))
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 1)
     out = uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
+    self.assertEqual(out.op, Ops.CONST)
     self.assertEqual(out.arg, 0)
 
   @unittest.skip("this test isn't valid uops")
   def test_noop_vectorize_fold(self):
-    d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
+    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
     idx = UOp.const(dtypes.int, 0)
-    ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
-    vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
-    x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
-    alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
-    out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
+    ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
+    vec = UOp(Ops.VECTORIZE, dtypes.float.vec(2), (ld,))
+    x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0)
+    alu = UOp(Ops.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
+    out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
     uops = to_uops_list([out])
-    self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
+    self.assertEqual(len([x for x in uops if x.op is Ops.VECTORIZE]), 0)
 
   def test_gep_vec_fold(self):
-    d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
-    d2 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
+    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
+    d2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
     idx = UOp.const(dtypes.int, 0)
     def _test_vec(geps, count=4):
-      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
-      out = UOp(UOps.STORE, dtypes.void, (d0, idx, vec))
+      vec = UOp(Ops.VECTORIZE, dtypes.float.vec(count), geps)
+      out = UOp(Ops.STORE, dtypes.void, (d0, idx, vec))
       uops = to_uops_list([out])
       if DEBUG >= 4:
        from tinygrad import Device
@@ -260,53 +260,53 @@ class TestUOpGraph(unittest.TestCase):
      return uops[-1].src[-1]
 
     # possible
-    val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
-    xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in range(4))
-    self.assertIs(_test_vec(xyzw).op, UOps.LOAD)
+    val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
+    xyzw = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in range(4))
+    self.assertIs(_test_vec(xyzw).op, Ops.LOAD)
 
     # unaligned
-    val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
-    wzyx = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
-    self.assertIs(_test_vec(wzyx).op, UOps.VECTORIZE)
+    val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
+    wzyx = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
+    self.assertIs(_test_vec(wzyx).op, Ops.VECTORIZE)
 
     # different_size
-    val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
-    xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
-    self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
-    val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
-    xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
-    self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE)
+    val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx))
+    xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
+    self.assertIs(_test_vec(xy+xy).op, Ops.VECTORIZE)
+    val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
+    xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
+    self.assertIs(_test_vec(xy, count=2).op, Ops.VECTORIZE)
 
    # different vals
-    val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
-    val2 = UOp(UOps.LOAD, dtypes.float.vec(2), (d2, idx))
-    xy1 = tuple(UOp(UOps.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
-    xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
-    self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
+    val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx))
+    val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2, idx))
+    xy1 = tuple(UOp(Ops.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
+    xy2 = tuple(UOp(Ops.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
+    self.assertIs(_test_vec(xy1+xy2).op, Ops.VECTORIZE)
 
   def test_gep_vec_const_fold(self):
     for vec_size in [2, 4, 8]:
       consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
-      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
-      uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
+      vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
+      uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
       for uop, const in zip(uops, consts):
         self.assertEqual(uop, const)
 
   def test_wmma_vectorize_fold(self):
     for i in [2, 4, 8]:
-      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
-      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
       acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
-      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
       uops = to_uops_list([wmma])
       self.assertEqual(uops[0], acc)
       self.assertEqual(len(uops), 1)
     for i in [2, 4, 8]:
-      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
-      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
+      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
       acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
-      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
       uops = to_uops_list([wmma])
      self.assertEqual(uops[0], acc)
       self.assertEqual(len(uops), 1)
@@ -314,109 +314,109 @@ class TestUOpGraph(unittest.TestCase):
   @unittest.skip("wmma is wrong here, it needs an arg")
   def test_wmma_vectorize_no_fold(self):
     for i in [4, 8]:
-      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                 tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
-                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
-      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+                tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
+      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
       uops = to_uops_list([wmma])
       self.assertEqual(uops[-1], wmma)
     for i in [4, 8]:
-      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                 tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
-                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
-      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+                tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
+      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
       uops = to_uops_list([wmma])
       self.assertEqual(uops[-1], wmma)
     for i in [2, 4, 8]:
-      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                 tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
-      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
      uops = to_uops_list([wmma])
       self.assertEqual(uops[-1], wmma)
    for i in [2, 4, 8]:
-      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+      var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
                 tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
-      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
-      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
+      wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
       uops = to_uops_list([wmma])
       self.assertEqual(uops[-1], wmma)
 
   def test_cast_alu_fold(self):
-    d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
-    d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
+    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
+    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
     idx = UOp.const(dtypes.int, 0)
-    ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
+    ld = UOp(Ops.LOAD, dtypes.int, (d1, idx))
     alu = ld.lt(1).cast(dtypes.bool)
-    out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
+    out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
     uops = to_uops_list([out])
-    self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
+    self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
 
   def test_double_cast_fold(self):
-    d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
-    d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
+    d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
+    d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
     idx = UOp.const(dtypes.int, 0)
-    ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
+    ld = UOp(Ops.LOAD, dtypes.int, (d1, idx))
     alu = ld.cast(dtypes.float).cast(dtypes.float)
-    out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
+    out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
     uops = to_uops_list([out])
-    self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
+    self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
 
   def test_depth_2_const_fold(self):
     v = UOp.variable("tmp", 0, 1)
-    c2 = UOp(UOps.CONST, dtypes.int, arg=2)
-    c4 = UOp(UOps.CONST, dtypes.int, arg=4)
-    vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
-    out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
+    c2 = UOp(Ops.CONST, dtypes.int, arg=2)
+    c4 = UOp(Ops.CONST, dtypes.int, arg=4)
+    vc = UOp(Ops.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
+    out = UOp(Ops.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
     uops = to_uops_list([out])
     self.assertEqual(len(uops), 3)
     out = uops[-1]
-    self.assertEqual(out.op, UOps.ALU)
+    self.assertEqual(out.op, Ops.ALU)
     self.assertEqual(out.arg, BinaryOps.ADD)
-    self.assertEqual(out.src[1].op, UOps.CONST)
+    self.assertEqual(out.src[1].op, Ops.CONST)
     self.assertEqual(out.src[1].arg, 6)
 
   def test_fold_gated_load(self):
-    glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
-    glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
+    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
+    glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
     idx = UOp.const(dtypes.int, 0)
-    ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
-    ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
-    uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
+    ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
+    ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
+    uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
     ld0 = uops[-1].src[-1]
     # the gate and invalid value are deleted from ld1
     self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
 
   def test_fold_gated_load_local(self):
-    glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
-    smem = UOp(UOps.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
-    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
-    st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
-    barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
-    ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
-    ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
-    uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
+    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
+    st = UOp(Ops.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
+    barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
+    ld0 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
+    ld1 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
+    uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
     ld0 = uops[-1].src[-1]
     # the gate and invalid value are deleted from ld1
     self.assertEqual(ld0.src[0], smem.index(lidx+2))
 
   def test_fold_gated_store(self):
-    glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
     idx0 = UOp.const(dtypes.int, 0)
     idx1 = UOp.const(dtypes.int, 0)
     val = UOp.const(dtypes.int, 42)
-    st0 = UOp(UOps.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
-    st1 = UOp(UOps.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
+    st0 = UOp(Ops.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
+    st1 = UOp(Ops.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
     uops = to_uops_list([st0, st1])
     # only the second store happens
     self.assertEqual(len(uops), 5)
@@ -424,23 +424,23 @@ class TestUOpGraph(unittest.TestCase):
   @unittest.skip("this is a uop type error")
   def test_asserts_bad_gate(self):
-    glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
     idx = UOp.const(dtypes.int, 0)
     bad_gate = UOp.const(dtypes.int, 1)
-    with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
+    with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
 
   def test_switched_range_order(self):
-    glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
     c0 = UOp.const(dtypes.int, 0)
     c2 = UOp.const(dtypes.int, 2)
     cf = UOp.const(dtypes.float, 0.0)
-    r1 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 0, False))
-    r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
-    alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
-    store = UOp(UOps.STORE, dtypes.void, (glbl, alu, cf))
+    r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False))
+    r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False))
+    alu = UOp(Ops.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
+    store = UOp(Ops.STORE, dtypes.void, (glbl, alu, cf))
     uops = to_uops_list([store])
-    ranges = [x for x in uops if x.op is UOps.RANGE]
-    endranges = [x for x in uops if x.op is UOps.ENDRANGE]
+    ranges = [x for x in uops if x.op is Ops.RANGE]
+    endranges = [x for x in uops if x.op is Ops.ENDRANGE]
     # ranges are closed in the right order
     self.assertEqual(endranges[-1].src[0], ranges[0])
@@ -449,90 +449,90 @@ def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())
 class TestExpander(unittest.TestCase):
   def test_expand_add_broadcast(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
     sink = expander_rewrite(e1+3)
-    assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
+    assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 4
     self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
 
   def test_contract_simple(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
     sink = expander_rewrite(con)
-    self.assertEqual(sink.op, UOps.VCONST)
+    self.assertEqual(sink.op, Ops.VCONST)
     self.assertTupleEqual(sink.arg, (0,1,2,3))
 
   def test_contract_axis_1(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
-    assert sink.src[0].op is UOps.VCONST
+    assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
+    assert sink.src[0].op is Ops.VCONST
     self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
     self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))

   def test_contract_axis_2(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
-    assert sink.src[0].op is UOps.VCONST
+    assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
+    assert sink.src[0].op is Ops.VCONST
     self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
     self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))

   def test_contract_axis_2_big(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
+    assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
     self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
     self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))

   def test_contract_multi_axis(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
-    sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
-    assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
+    sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
+    assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (4, 2))
     self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
-    sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
-    assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
+    sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
+    assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (4, 2))
     self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))

   def test_contract_mid(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
-    assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
+    assert sink.op is Ops.EXPAND and sink.arg == ((1,2),(3,2))
+    assert sink.src[0].op is Ops.VCONST and len(sink.src[0].arg) == 8
     self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))

   def test_contract_no_expand(self):
-    e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
+    e1 = UOp(Ops.DEFINE_VAR, dtypes.int)
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.VECTORIZE and len(sink.src) == 2
+    assert sink.op is Ops.VECTORIZE and len(sink.src) == 2
     assert sink.src[0] == sink.src[1]

   def test_contract_half_expand(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
-    con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
+    con = UOp(Ops.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
     sink = expander_rewrite(con)
-    assert sink.op is UOps.VCONST and len(sink.arg) == 8
+    assert sink.op is Ops.VCONST and len(sink.arg) == 8
     assert sink.arg[0] == sink.arg[1]
     assert sink.arg[0] != sink.arg[2]
     assert sink.arg[6] == sink.arg[7]

   def test_expand_same_axis(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
+    e2 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
     sink = expander_rewrite(e1+e2)
-    self.assertEqual(sink.op, UOps.EXPAND)
-    self.assertEqual(sink.src[0].op, UOps.VCONST)
+    self.assertEqual(sink.op, Ops.EXPAND)
+    self.assertEqual(sink.src[0].op, Ops.VCONST)
     self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))

   def test_expand_different_axis(self, flip=False):
-    e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
+    e2 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
     sink = expander_rewrite((e2+e1) if flip else (e1+e2))
-    assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
+    assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16
     assert sink.arg == ((1, 4), (2, 4))
     self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
@@ -540,47 +540,47 @@ class TestExpander(unittest.TestCase):

   @unittest.skip("no longer supported")
   def test_reduce_known_axis(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
-    sink = UOp(UOps.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
+    e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
     sink = expander_rewrite(sink)
-    assert sink.op is UOps.CONST
+    assert sink.op is Ops.CONST
     self.assertEqual(sink.arg, 3*(0+1+2+3))

   @unittest.skip("no longer supported")
   def test_reduce_const(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
-    sink = UOp(UOps.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
+    e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
     sink = expander_rewrite(sink)
-    assert sink.op is UOps.CONST
+    assert sink.op is Ops.CONST
     self.assertEqual(sink.arg, 3*4)

   @unittest.skip("no longer supported")
   def test_double_expand(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
-    e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((1,2),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
+    e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
+    e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((1,2),))
     sink = expander_rewrite(e)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 8
+    assert sink.op is Ops.EXPAND and len(sink.src) == 8
     assert sink.arg == ((1, 2), (2, 4))
     self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])

   @unittest.skip("no longer supported")
   def test_double_expand_reverse(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
-    e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
+    e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((2,2),))
     sink = expander_rewrite(e)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 8
+    assert sink.op is Ops.EXPAND and len(sink.src) == 8
     assert sink.arg == ((1, 4), (2, 2))
     self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])

   @unittest.skip("no longer supported")
   def test_double_expand_middle(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
-    e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
-    e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
+    e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
+    e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
+    e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((2,2),))
     sink = expander_rewrite(e)
-    assert sink.op is UOps.EXPAND and len(sink.src) == 8
+    assert sink.op is Ops.EXPAND and len(sink.src) == 8
     assert sink.arg == ((1, 2), (2, 2), (3, 2))
     self.assertListEqual([x.arg for x in sink.src], [0, 1, 4, 5, 2, 3, 6, 7])
@@ -588,106 +588,106 @@ class TestExpander(unittest.TestCase):

   @unittest.expectedFailure
   @unittest.skip
   def test_reduce_different_axis(self):
-    e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
-    e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
-    sink = UOp(UOps.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
+    e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
+    e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
+    sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
     sink = expander_rewrite(sink)
     print(sink)

 class TestLoadStoreFolder(unittest.TestCase):
   def test_simple_load_fold(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
-    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
-    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
+    load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
+    sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink.sink())
-    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
+    assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1

   def test_two_load_fold(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
-    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
-    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
+    load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
+    sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink.sink())
-    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
+    assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2

   def test_simple_load_fold_gated(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
-    gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
-    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
-    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
+    gate = UOp(Ops.DEFINE_VAR, dtypes.bool)
+    load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
+    sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink.sink())
-    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
-    single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
-    self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
+    assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
+    single_load = [x for x in sink.sparents if x.op is Ops.LOAD][0]
+    self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)

   def test_simple_load_dont_fold_different_gated(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
     gate = UOp.variable("g1", False, True, dtypes.bool)
     gate2 = UOp.variable("g2", False, True, dtypes.bool)
-    load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)]
-    sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
+    load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)]
+    sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink.sink())
-    assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
+    assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3

   def test_simple_store_fold(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
-    load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
-    sink = UOp(UOps.SINK, dtypes.void, tuple(load))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
+    load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
+    sink = UOp(Ops.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)
-    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
+    assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1

   def test_simple_store_fold_gate(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
     gate = UOp.variable("g1", False, True, dtypes.bool)
-    load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
-    sink = UOp(UOps.SINK, dtypes.void, tuple(load))
+    load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
+    sink = UOp(Ops.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)
-    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
-    one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
+    assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
+    one_store = [x for x in sink.sparents if x.op is Ops.STORE][0]
     assert len(one_store.src) == 3
     assert str(one_store.src[2]) == str(gate) # huh, why do i need str here?

   def test_simple_store_dont_fold(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
     gate = UOp.variable("g1", False, True, dtypes.bool)
     gate2 = UOp.variable("g2", False, True, dtypes.bool)
-    load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
-    sink = UOp(UOps.SINK, dtypes.void, tuple(load))
+    load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
+    sink = UOp(Ops.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)
-    assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
+    assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3

 class TestIFUOps(unittest.TestCase):
   def test_create_ifs(self):
-    gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
-    valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
-    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
+    gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
+    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
     gate = valid&(lidx.ne(2))
     idx = UOp.const(dtypes.int, 0)
-    st = UOp(UOps.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42)))
-    barrier = UOp(UOps.BARRIER, dtypes.void, (st,))
-    lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
-    store = UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
-    sink = UOp(UOps.SINK, dtypes.void, (store,))
+    st = UOp(Ops.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42)))
+    barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
+    lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
+    store = UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
+    sink = UOp(Ops.SINK, dtypes.void, (store,))
     sink = full_graph_rewrite(sink)
-    if_uops = [u for u in sink.parents if u.op is UOps.IF]
+    if_uops = [u for u in sink.parents if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
     for st in sink.src:
       self.assertEqual(len(st.src), 2)

   def test_expand_ifs_one_gate(self):
-    gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
-    valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
-    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
+    gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
+    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
     gate = valid&(lidx.ne(2))
-    st = UOp(UOps.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
-    barrier = UOp(UOps.BARRIER, dtypes.void, (st,))
-    lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
-    stores = [UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
-    sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
+    st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
+    barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
+    lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
+    stores = [UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
+    sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
     sink = full_graph_rewrite(sink)
-    if_uops = [u for u in sink.parents if u.op is UOps.IF]
+    if_uops = [u for u in sink.parents if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
     for st in sink.src:
@@ -696,14 +696,14 @@ class TestIFUOps(unittest.TestCase):

   # this will be fixed with the merge gated stores bounty
   @unittest.expectedFailure
   def test_expand_ifs_dumb(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
-    lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
+    lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
     gate = valid&(lidx.ne(2))
-    stores = [UOp(UOps.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
-    sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
+    stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
+    sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
     sink = full_graph_rewrite(sink)
-    if_uops = [u for u in sink.parents if u.op is UOps.IF]
+    if_uops = [u for u in sink.parents if u.op is Ops.IF]
     self.assertEqual(len(if_uops), 1)
     self.assertEqual(if_uops[0].src[0], gate)
     for st in sink.src:
diff --git a/test/test_uops.py b/test/test_uops.py
index 67c681bc03..47b80780d0 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.helpers import CI, DEBUG, getenv, Context
 from tinygrad.dtype import dtypes, DType
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
+from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
 from tinygrad.renderer import Program
 from tinygrad.engine.schedule import create_schedule, to_si
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -23,18 +23,18 @@ def _uops_to_prg(uops_list):
   return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))

-def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
+def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(src), arg))
   return uops[-1]

 def _test_single_value(vals, op, dts):
   uops = []
   output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
-  buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
-  buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
-  loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
-  alu = uop(uops, UOps.ALU, output_dtype, loads, op)
-  out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
+  buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
+  buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
+  loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i], uop(uops, Ops.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
+  alu = uop(uops, Ops.ALU, output_dtype, loads, op)
+  out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu))
   buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
   buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
   prg = _uops_to_prg([out])
@@ -46,10 +46,10 @@ def _test_single_value(vals, op, dts):
 def _test_single_value_const(vals, op, dts):
   uops = []
   output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
-  buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
-  loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
-  alu = uop(uops, UOps.ALU, output_dtype, loads, op)
-  out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
+  buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
+  loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
+  alu = uop(uops, Ops.ALU, output_dtype, loads, op)
+  out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu))
   buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
   prg = _uops_to_prg([out])
   prg.exec([buf])
@@ -59,9 +59,9 @@ def _test_single_value_const(vals, op, dts):

 def _test_uops_result(output_dtype, uops, res):
   # uops = []
-  buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
+  buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
   # res = output_fn(uops)
-  out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
+  out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), res))
   buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
   prg = _uops_to_prg([out])
   prg.exec([buf])
@@ -244,63 +244,63 @@ class TestConstantFolding(unittest.TestCase):
     si = create_schedule([t.lazydata])
     assert len(si) == 1
     ji = lower_schedule_item(si[-1])
-    assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
+    assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"

 class TestGatedStoreRewrite(unittest.TestCase):
   @unittest.expectedFailure
   def test_tiny_gate_store(self):
-    gmem = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
+    gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
     idx = gidx0 * UOp.const(dtypes.int, 2)
     val = UOp.const(dtypes.float, 42.0)
     gate = gidx0.lt(UOp.const(dtypes.int, 1))
-    store = UOp(UOps.STORE, dtypes.void, (gmem, idx, val, gate))
+    store = UOp(Ops.STORE, dtypes.void, (gmem, idx, val, gate))
     uops = to_uops_list([store])
     if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
-    if_uop = next(u for u in uops if u.op is UOps.IF)
-    endif = next(u for u in uops if u.op is UOps.ENDIF)
+    if_uop = next(u for u in uops if u.op is Ops.IF)
+    endif = next(u for u in uops if u.op is Ops.ENDIF)
     assert endif.src[0] is if_uop
     gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
     self.assertEqual(len(gated_uops), 1)
-    self.assertIs(gated_uops[-1].op, UOps.STORE)
+    self.assertIs(gated_uops[-1].op, Ops.STORE)

   @unittest.expectedFailure
   def test_gate_some_stores(self):
-    gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
-    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
+    gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
+    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
     idx = gidx0*UOp.const(dtypes.int, 2)
     val = UOp.const(dtypes.float, 42.0)
     gate = gidx0.lt(UOp.const(dtypes.int, 1))
     stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
     uops = linearize_uop(stores)
     if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
-    if_uop = next(u for u in uops if u.op is UOps.IF)
-    endif = next(u for u in uops if u.op is UOps.ENDIF)
+    if_uop = next(u for u in uops if u.op is Ops.IF)
+    endif = next(u for u in uops if u.op is Ops.ENDIF)
     assert endif.src[0] is if_uop
     gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
     self.assertEqual(len(gated_uops), 1)
-    self.assertIs(gated_uops[-1].op, UOps.STORE)
+    self.assertIs(gated_uops[-1].op, Ops.STORE)

   # scaled down version of TestLinearizerDumb.test_unmerged_ifs
   @unittest.expectedFailure
   def test_merge_ifs_alt(self):
-    gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
-    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
+    gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
+    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
     idx = gidx0*UOp.const(dtypes.int, 2)
     val = UOp.const(dtypes.float, 42.0)
     gate = gidx0.lt(UOp.const(dtypes.int, 1))
     stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
     uops = linearize_uop(stores)
     if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
-    ifs = [u for u in uops if u.op is UOps.IF]
-    endifs = [u for u in uops if u.op is UOps.ENDIF]
+    ifs = [u for u in uops if u.op is Ops.IF]
+    endifs = [u for u in uops if u.op is Ops.ENDIF]
     self.assertEqual(len(ifs), 1)
     self.assertEqual(len(endifs), 1)
     gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
     self.assertEqual(len(gated_uops), 2)
-    for x in gated_uops: self.assertIs(x.op, UOps.STORE)
+    for x in gated_uops: self.assertIs(x.op, Ops.STORE)

 class TestLocalAccess(unittest.TestCase):
   # NOTE: this is failing on METAL CI, no idea why. Works locally.
@@ -308,44 +308,44 @@ class TestLocalAccess(unittest.TestCase):
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
   def test_local_basic(self):
     uops = []
-    smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
-    st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
-    barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
-    sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
+    smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
+    st = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
+    barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
+    sres = uop(uops, Ops.LOAD, dtypes.float32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), barr))
     self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)

   @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
   def test_local_indirect(self):
     uops = []
-    smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
-    st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
-    st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
-    barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
-    ofs = uop(uops, UOps.LOAD, dtypes.int32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), barr))
-    sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
+    smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
+    st1 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
+    st2 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 2), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
+    barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))
+    ofs = uop(uops, Ops.LOAD, dtypes.int32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), barr))
+    sres = uop(uops, Ops.LOAD, dtypes.int32, (smem, ofs))
     self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)

 @unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
 class TestAssembly(unittest.TestCase):
   def test_bitshift_left(self):
-    g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
-    c1 = UOp(UOps.CONST, dtypes.int, (), 2)
-    c2 = UOp(UOps.CONST, dtypes.int, (), 3)
-    l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
-    a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
-    a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
+    g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
+    c1 = UOp(Ops.CONST, dtypes.int, (), 2)
+    c2 = UOp(Ops.CONST, dtypes.int, (), 3)
+    l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1))
+    a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
+    a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
     uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
     Device[Device.DEFAULT].renderer.render("test", uops)
     self.assertEqual(uops[-1].arg, BinaryOps.SHL)
     self.assertEqual(uops[-2].arg, BinaryOps.MUL)

   def test_bitshift_right(self):
-    g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
-    c1 = UOp(UOps.CONST, dtypes.int, (), 2)
-    c2 = UOp(UOps.CONST, dtypes.int, (), 3)
-    l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
-    a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
-    a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
+    g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
+    c1 = UOp(Ops.CONST, dtypes.int, (), 2)
+    c2 = UOp(Ops.CONST, dtypes.int, (), 3)
+    l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1))
+    a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
+    a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
     uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
     Device[Device.DEFAULT].renderer.render("test", uops)
     self.assertEqual(uops[-1].arg, BinaryOps.SHR)
@@ -354,38 +354,38 @@ class TestAssembly(unittest.TestCase):

 class TestUOpMethod(unittest.TestCase):
   @unittest.skip("uops lt no longer ordered")
   def test_compare_alu_same_src_different_arg(self):
-    a = UOp(UOps.CONST, dtypes.float, (), 2.0)
-    b = UOp(UOps.CONST, dtypes.float, (), 3.0)
+    a = UOp(Ops.CONST, dtypes.float, (), 2.0)
+    b = UOp(Ops.CONST, dtypes.float, (), 3.0)

-    add = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)
-    mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
+    add = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.ADD)
+    mul = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.MUL)
     assert (add < mul) or (mul < add), "add and mul with same src should have an order"

   def test_uop_variables(self):
     a = UOp.variable("a", 1, 10)
     uop_var = UOp.const(dtypes.int, a)
-    st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
+    st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
       ShapeTracker.from_shape((2, a)).to_uop()))
     ast_vars = (st_var+uop_var).variables()
     self.assertEqual(len(ast_vars), 1)
     self.assertEqual(ast_vars[0], a)

   def test_const_factor(self):
-    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 8))
-    self.assertEqual(UOp(UOps.CONST, dtypes.int, (), 17).const_factor(), 17)
+    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8))
+    self.assertEqual(UOp(Ops.CONST, dtypes.int, (), 17).const_factor(), 17)
     self.assertEqual(gidx0.const_factor(), 1)
     self.assertEqual((gidx0*3).const_factor(), 3)
     self.assertEqual((gidx0*3+6).const_factor(), 3)
     self.assertEqual((gidx0*3+1).const_factor(), 1)

   def test_replace(self):
-    x = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
+    x = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
     self.assertIs(x.replace(arg=None).arg, None)
     with self.assertRaises(AssertionError): x.replace(field="a")

 class TestUOpStr(unittest.TestCase):
   def test_uop_str(self):
-    a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
+    a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
     for _ in range(20): a = a + a
     assert len(str(a)) < 10_000, "exponential string growth"
     assert str(eval(str(a))) == str(a)
@@ -394,11 +394,11 @@ class TestUOpStr(unittest.TestCase):
     t = t + t * Tensor.rand(10)
     # nice big complicated uop
     with Context(NOOPT=1):
-      sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
+      sink = UOp(Ops.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
     self.assertEqual(sink, eval(str(sink)))

   def test_vectorized_str(self):
-    vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
+    vec = UOp(Ops.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
     assert str(eval(str(vec))) == str(vec)

 @unittest.skip("uop no longer has order like this")
@@ -406,23 +406,23 @@ class TestIndexingOrdering(unittest.TestCase):
   # NOTE: these tests skip type_verify since they add dtype to STORE
   @unittest.expectedFailure
   def test_simple_order(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
-    st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
+    st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
     uops = to_uops_list([st1, st0], skip_check=True)
-    stores = [st for st in uops if st.op is UOps.STORE]
+    stores = [st for st in uops if st.op is Ops.STORE]
     assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

   @unittest.expectedFailure
   def test_ordering_multi_output(self):
-    buf0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    buf1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
-    st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
-    st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
-    st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
-    st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
+    buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
+    st0_0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
+    st1_0 = UOp(Ops.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
+    st0_1 = UOp(Ops.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
+    st1_1 = UOp(Ops.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
     uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True)
-    stores = [st for st in uops if st.op is UOps.STORE]
+    stores = [st for st in uops if st.op is Ops.STORE]
     print("\n".join(map(str, stores)))
     # buf0 stores come first
     self.assertEqual(stores[0].src[0].arg, stores[1].src[0].arg)
@@ -433,12 +433,12 @@ class TestIndexingOrdering(unittest.TestCase):
     assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

   def test_simple_order_with_special(self):
-    buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
-    gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
-    st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
-    st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
+    buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
+    gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
+    st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
+    st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
     uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
-    stores = [st for st in uops if st.op is UOps.STORE]
+    stores = [st for st in uops if st.op is Ops.STORE]
     assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"

 class TestUPatHelpers(unittest.TestCase):
@@ -447,7 +447,7 @@ class TestUPatHelpers(unittest.TestCase):
     self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
     self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
     with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
-      test_upat = UPat(UOps.CONST, dtypes.bool)
+      test_upat = UPat(Ops.CONST, dtypes.bool)
       self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])

 if __name__ == '__main__':
diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py
index 806e5db328..25865639ab 100644
--- a/test/test_uops_stats.py
+++ b/test/test_uops_stats.py
@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, GlobalCounters
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule_item
 from tinygrad.codegen.linearize import linearize_uop
-from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
+from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, Ops, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -119,23 +119,23 @@ class TestUOpsStats(unittest.TestCase):

   #MULACC should have the same stats as MUL + ADD
   def test_mulacc(self):
-    globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
-    o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
-    o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
-    u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
-    u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
-    u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
-    u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
-    u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
+    globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
+    o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
+    o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
+    u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
+    u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
+    u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
+    u4 = UOp(Ops.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
+    u5 = UOp(Ops.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
     uops = linearize_uop(u5.sink())

-    globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
-    o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
-    o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
-    u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
-    u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
-    u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
-    u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
+    globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
+    o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
+    o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
+    u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
+    u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
+    u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
+    u4 = UOp(Ops.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
     uops_fma = linearize_uop(u4.sink())

     self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
diff --git a/test/test_viz.py b/test/test_viz.py
index 936e89f84b..02f8820cc5 100644
--- a/test/test_viz.py
+++ b/test/test_viz.py
@@ -1,7 +1,7 @@
 from typing import Dict, List, Optional
 import unittest
 from tinygrad.dtype import dtypes
-from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
-  graph_rewrite, contexts, track_rewrites
+from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
+  graph_rewrite, contexts, track_rewrites
 from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@@ -27,7 +27,7 @@ class TestViz(unittest.TestCase):
     pm = PatternMatcher([
       (UPat.var("x")*1, lambda x:x),
     ])
-    a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
+    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
     uops = helper_test_viz(a*1, pm)
     self.assertEqual(len(uops), 1)
     self.assertEqual(uops[0], a)
@@ -37,21 +37,21 @@ class TestViz(unittest.TestCase):
       (UPat.var("x")+UPat.var("x"), lambda x:x*2),
       (UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
     ])
-    a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
+    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
     uops = helper_test_viz(a+a, pm)
     self.assertEqual(len(uops), 2)
     self.assertEqual(uops[0], a*2)
     self.assertEqual(uops[1], graph_rewrite(a+a, pm))

   def test_rewrite_with_ctx(self):
-    a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
-    b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
+    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
+    b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
     def store_load(ctx:Dict[UOp, None], x:UOp) -> Optional[UOp]:
       if x in ctx: return None
       ctx[x] = None
       return UOp.store(*x.src, x)
     pm = PatternMatcher([
-      (UPat(UOps.LOAD, name="x"), store_load),
+      (UPat(Ops.LOAD, name="x"), store_load),
     ])
     uops = helper_test_viz(a+b, pm, {})
     self.assertEqual(len(uops), 2)
@@ -61,7 +61,7 @@ class TestViz(unittest.TestCase):
     simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
     @track_rewrites(named=True)
     def do_rewrite(x:UOp): return graph_rewrite(x, simple)
-    ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
+    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
     do_rewrite(ld*1)
     do_rewrite(ld*2)
     ret = get_metadata(contexts)
@@ -79,13 +79,13 @@ class TestViz(unittest.TestCase):
     def do_rewrite(x:UOp):
       x = graph_rewrite(x, simple) # NOTE: viz tracks this
       raise Exception("test")
-    ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
+    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
     with self.assertRaises(Exception): do_rewrite(ld*1)
     ret = get_metadata(contexts)
     self.assertEqual(len(ret), 1)

   def test_fold_const(self):
-    a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
+    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
     graph = uop_to_json(a)
     assert not any(v[0].startswith("CONST") for v in graph.values())
     assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
diff --git a/test/test_winograd.py b/test/test_winograd.py
index f0b4dc2211..dcf1f9411c 100644
--- a/test/test_winograd.py
+++ b/test/test_winograd.py
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor, GlobalCounters
-from tinygrad.ops import UOps
+from tinygrad.ops import Ops
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule
@@ -23,7 +23,7 @@ class TestWinograd(unittest.TestCase):
     sched = create_schedule([out.lazydata])

     for i,s in enumerate(sched):
-      if s.ast.op is not UOps.SINK: continue
+      if s.ast.op is not Ops.SINK: continue
       ops = s.ast.parents
       with Timing(f"linearize {i} with {len(ops):4d} ops: "):
         l = Kernel(s.ast)
diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py
index e29016e476..4cd031efda 100644
--- a/test/unit/test_graph_rewrite.py
+++ b/test/unit/test_graph_rewrite.py
@@ -1,7 +1,7 @@
 import unittest, math
 from tinygrad import dtypes
 from tinygrad.helpers import all_same
-from tinygrad.ops import UOp, UOps, BinaryOps, exec_alu
+from tinygrad.ops import UOp, Ops, BinaryOps, exec_alu
 from tinygrad.codegen.uopgraph import full_graph_rewrite

 # Helper function to apply the graph rewrite
@@ -9,12 +9,12 @@ def apply_rewrite(expr):
   return full_graph_rewrite(expr.sink()).src[0]

 def evaluate_uop(uop, variables):
-  if uop.op == UOps.CONST:
+  if uop.op == Ops.CONST:
     return uop.arg
-  elif uop.op == UOps.DEFINE_VAR:
+  elif uop.op == Ops.DEFINE_VAR:
     var_name = uop.arg[0]
     return variables[var_name]
-  elif uop.op == UOps.ALU:
+  elif uop.op == Ops.ALU:
     src_values = [evaluate_uop(src, variables) for src in uop.src]
     return exec_alu(uop.arg, uop.dtype, src_values)
   else:
@@ -23,12 +23,12 @@ def evaluate_uop(uop, variables):
 class TestArithmeticSimplifications(unittest.TestCase):
   def test_full_graph_rewrite_division_by_zero(self):
     optimized_div_uop = apply_rewrite(UOp.const(dtypes.float32, 10.0) / UOp.const(dtypes.float32, 0.0))
-    self.assertEqual(optimized_div_uop.op, UOps.CONST)
+    self.assertEqual(optimized_div_uop.op, Ops.CONST)
     self.assertTrue(math.isinf(optimized_div_uop.arg) or math.isnan(optimized_div_uop.arg))

   def test_full_graph_rewrite_redundant_operations(self):
     optimized_uop = apply_rewrite((UOp.const(dtypes.float32, 10.0) + UOp.const(dtypes.float32, 0.0)) * UOp.const(dtypes.float32, 1.0))
-    self.assertEqual(optimized_uop.op, UOps.CONST)
+    self.assertEqual(optimized_uop.op, Ops.CONST)
     self.assertEqual(optimized_uop.arg, 10.0)

   def test_full_graph_rewrite_large_graph(self):
@@ -36,17 +36,17 @@ class TestArithmeticSimplifications(unittest.TestCase):
     for i in range(1, 101):
       prev_uop += UOp.const(dtypes.int32, i)
     optimized_uop = apply_rewrite(prev_uop)
-    self.assertEqual(optimized_uop.op, UOps.CONST)
+    self.assertEqual(optimized_uop.op, Ops.CONST)
     self.assertEqual(optimized_uop.arg, sum(range(1, 101)))

   def test_full_graph_rewrite_division_by_one(self):
     optimized_uop = apply_rewrite(UOp.const(dtypes.float32, 42.0) / UOp.const(dtypes.float32, 1.0))
-    self.assertEqual(optimized_uop.op, UOps.CONST)
+    self.assertEqual(optimized_uop.op, Ops.CONST)
     self.assertEqual(optimized_uop.arg, 42.0)

   def test_full_graph_rewrite_modulo_by_one(self):
     optimized_uop = apply_rewrite(UOp.const(dtypes.int32, 42) % UOp.const(dtypes.int32, 1))
-    self.assertEqual(optimized_uop.op, UOps.CONST)
+    self.assertEqual(optimized_uop.op, Ops.CONST)
     self.assertEqual(optimized_uop.arg, 0)
@@ -90,7 +90,7 @@ class TestFoldingAndReduction(unittest.TestCase):
     inner_range = UOp.range(dtypes.int32, 0, 4, 1)
     expr = (outer_range * 10) + inner_range
     optimized_reduce_uop = apply_rewrite(expr.reduce(BinaryOps.ADD, outer_range, inner_range))
-    self.assertEqual(optimized_reduce_uop.op, UOps.CONST)
+    self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
     self.assertEqual(optimized_reduce_uop.arg, sum((i * 10) + j for i in range(8) for j in range(4)))
@@ -98,32 +98,32 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_modulo_folding_with_define_var(self):
     x_var_uop = UOp.variable('x', 0, 100)
     optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
-    self.assertEqual(optimized_mod_uop.op, UOps.CONST)
+    self.assertEqual(optimized_mod_uop.op, Ops.CONST)
     self.assertEqual(optimized_mod_uop.arg, 2)

   def test_full_graph_rewrite_division_folding_with_define_var(self):
     n_var_uop = UOp.variable('n', 1, 1000)
     optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
-    self.assertEqual(optimized_div_uop.op, UOps.ALU)
+    self.assertEqual(optimized_div_uop.op, Ops.ALU)
     self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
     self.assertEqual(optimized_div_uop.src[1].arg, 2)

   def test_full_graph_rewrite_complex_mod_div_folding(self):
     k_var_uop = UOp.variable('k', 0, 50)
     optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
-    self.assertEqual(optimized_div_uop.op, UOps.CONST)
+    self.assertEqual(optimized_div_uop.op, Ops.CONST)
     self.assertEqual(optimized_div_uop.arg, 1)

   def test_graph_rewrite_div_folding_bug(self):
-    lhs = UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
-      UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
-      UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
+    lhs = UOp(Ops.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
+      UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
+      UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
     rhs = UOp.const(dtypes.int.vec(4), 2)
     unopt = lhs.lt(rhs)
     opt = apply_rewrite(unopt)
     print(unopt)
     print(opt)
-    if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))
+    if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src))

   def test_full_graph_rewrite_modulo_large_divisor(self):
     x_var_uop = UOp.variable('x', 1, 5)
@@ -180,12 +180,12 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):

   def test_gep_on_vconst(self):
     # GEP on a VCONST to extract a single element
-    vconst = UOp(UOps.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
+    vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
     self.assertEqual(apply_rewrite(vconst.gep(2)).arg, 3.0)

   def test_gep_tuple_on_vconst(self):
     # GEP on a VCONST using a tuple to extract multiple elements
-    vconst = UOp(UOps.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
+    vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
     optimized_uop = apply_rewrite(vconst.gep((1, 3)))
     self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [8.0, 10.0])
@@ -198,7 +198,7 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
   def test_vectorize_multiple_elements(self):
     # Vectorizing multiple elements using GEP
     base_vector = UOp.const(dtypes.float32.vec(4), (5.0, 10.0, 15.0, 20.0))
-    vectorized_uop = UOp(UOps.VECTORIZE, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
+    vectorized_uop = UOp(Ops.VECTORIZE, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
     optimized_uop = apply_rewrite(vectorized_uop)
     self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [5.0, 10.0, 15.0, 20.0])
diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py
diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py
index 22c9c6f630..b4a1715bcd 100644
--- a/test/unit/test_pattern_matcher.py
+++ b/test/unit/test_pattern_matcher.py
@@ -1,13 +1,13 @@
 import unittest, itertools
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOps, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
+from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
 from tinygrad.ops import PatternMatcher, UPat
 class TestPatternMatcher(unittest.TestCase):
   def test_simple_match(self):
-    matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.int, arg=1)
+    matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.int, arg=1)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), None)
@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
       nonlocal match_cnt
       match_cnt += 1
       assert len(x.src) == 0
-      return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
-    matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
+      return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
+    matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     # second rewrite shouldn't match anything
     c1 = matcher.rewrite(c1)
     c1 = matcher.rewrite(c1)
@@ -42,9 +42,9 @@ class TestPatternMatcher(unittest.TestCase):
     def fxn(ctx, x):
       ctx.append(True)
       assert len(x.src) == 0
-      return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
-    matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
+      return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
+    matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
     # second rewrite shouldn't match anything
     ctx = []
     c1 = matcher.rewrite(c1, ctx)
@@ -52,33 +52,33 @@ class TestPatternMatcher(unittest.TestCase):
     self.assertEqual(len(ctx), 1)
   def test_uop(self):
-    matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
+    matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), None)
   def test_uop_set(self):
-    matcher = PatternMatcher([(UPat({UOps.CONST, UOps.CAST}, name="x"), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
-    c2 = UOp(UOps.CAST, dtypes.int, (c1,))
-    c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c4 = UOp(UOps.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
+    matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
+    c2 = UOp(Ops.CAST, dtypes.int, (c1,))
+    c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c4 = UOp(Ops.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), c2)
     self.assertEqual(matcher.rewrite(c4), None)
   def test_arg(self):
     matcher = PatternMatcher([
-      (UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
-      (UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
-      (UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
+      (UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
+      (UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
+      (UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
     ])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
-    c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
-    c3 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
-    c4 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
-    c5 = UOp(UOps.CONST, dtypes.int, arg=-1)
+    c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
+    c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
+    c3 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
+    c4 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
+    c5 = UOp(Ops.CONST, dtypes.int, arg=-1)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), c2)
     self.assertEqual(matcher.rewrite(c3), c3)
@@ -87,17 +87,17 @@ class TestPatternMatcher(unittest.TestCase):
   def test_filter_arg(self):
     matcher = PatternMatcher([
-      (UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
+      (UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
        lambda x,c: x if c.arg in {1, -1} else None)
     ])
-    y1 = UOp(UOps.CONST, dtypes.int, arg=1)
-    y2 = UOp(UOps.CONST, dtypes.int, arg=2)
-    y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
-    c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
-    c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
-    c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
-    c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
-    c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
+    y1 = UOp(Ops.CONST, dtypes.int, arg=1)
+    y2 = UOp(Ops.CONST, dtypes.int, arg=2)
+    y3 = UOp(Ops.CONST, dtypes.int, arg=-1)
+    c1 = UOp(Ops.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
+    c2 = UOp(Ops.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
+    c3 = UOp(Ops.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
+    c4 = UOp(Ops.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
+    c5 = UOp(Ops.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), None)
    self.assertEqual(matcher.rewrite(c3), c3)
@@ -105,37 +105,37 @@ class TestPatternMatcher(unittest.TestCase):
     self.assertEqual(matcher.rewrite(c5), c5)
   def test_dup_name(self):
-    matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
-    y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
-    c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
+    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
+    y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
+    c2 = UOp(Ops.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), c1)
   def test_dtype(self):
-    matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
+    matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), None)
   def test_dtype_set(self):
-    matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
-    c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
-    c4 = UOp(UOps.CONST, dtypes.int, arg=1)
+    matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
+    c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
+    c4 = UOp(Ops.CONST, dtypes.int, arg=1)
     self.assertEqual(matcher.rewrite(c1), c1)
     self.assertEqual(matcher.rewrite(c2), c2)
     self.assertEqual(matcher.rewrite(c3), None)
     self.assertEqual(matcher.rewrite(c4), None)
   def test_src_one(self):
-    matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
-    c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
+    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
+    c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
     self.assertEqual(matcher.rewrite(c3), c3)
     self.assertEqual(matcher.rewrite(c2), None)
     # that CONST/ALU -> ALU/CONST rewrite is now instant
@@ -149,46 +149,46 @@ class TestPatternMatcher(unittest.TestCase):
     """
   def test_src_permutations(self):
-    matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
-    c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
-    c4 = UOp(UOps.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
-    c5 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
-    c6 = UOp(UOps.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
+    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
+    c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
+    c4 = UOp(Ops.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
+    c5 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
+    c6 = UOp(Ops.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
     self.assertEqual(matcher.rewrite(c3), None)
     self.assertEqual(matcher.rewrite(c4), c4)
     self.assertEqual(matcher.rewrite(c5), c5)
     self.assertEqual(matcher.rewrite(c6), None)
   def test_src_repeat(self):
-    matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
-    c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
-    c4 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
+    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
+    c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
+    c4 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
     self.assertEqual(matcher.rewrite(c3), c3)
     self.assertEqual(matcher.rewrite(c4), None)
   def test_allow_len(self):
-    matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
-    c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
-    c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
-    c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
-    c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
+    matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
+    c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
+    c4 = UOp(Ops.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
+    c5 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
+    c6 = UOp(Ops.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
     self.assertEqual(matcher.rewrite(c4), None)
     self.assertEqual(matcher.rewrite(c5), None)
     self.assertEqual(matcher.rewrite(c6), c6)
   def test_deep_src_permutations(self):
-    c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
-    c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
+    c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
+    c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
     u1 = (c1 + c2) + c1
     u2 = (c2 + c1) + c1
     matcher = PatternMatcher([
-      (UPat(UOps.ALU, src=[UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
+      (UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
     ])
     self.assertIsNotNone(matcher.rewrite(u1))
     self.assertIsNotNone(matcher.rewrite(u2))
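These tests double as the best documentation of the `PatternMatcher`/`UPat` contract: a rule is a `(pattern, callback)` pair, `name=` binds a matched node to a callback argument, and `rewrite` returns `None` when no rule fires. A condensed sketch of the same contract, using only constructs that appear in the tests above (illustrative, not part of the patch):

from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, PatternMatcher, UPat

# one rule: match a float CONST, bind it to "x", and return it unchanged
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
assert matcher.rewrite(c1) == c1    # pattern matched, callback ran
assert matcher.rewrite(c2) is None  # dtype filter rejected the int CONST
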
diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py
index 42d040e1ad..78713a43b3 100644
--- a/test/unit/test_shapetracker.py
+++ b/test/unit/test_shapetracker.py
@@ -5,14 +5,14 @@ from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
 from tinygrad import Variable
-from tinygrad.ops import UOp, UOps, graph_rewrite
+from tinygrad.ops import UOp, Ops, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product
 def shapetracker_getitem(st:ShapeTracker, val:int):
   idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
   idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
-  assert idx.op is UOps.CONST and valid.op is UOps.CONST
+  assert idx.op is Ops.CONST and valid.op is Ops.CONST
   return idx.arg, valid.arg
 class CheckingShapeTracker:
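The `shapetracker_getitem` helper in this file shows how a `ShapeTracker` lookup becomes UOp index math that `sym` then folds to constants. A standalone sketch of that path, using the same calls as the helper above (assumed behavior for a static contiguous view, not part of the patch):

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite
from tinygrad.codegen.uopgraph import sym
from tinygrad.shape.shapetracker import ShapeTracker

# index a flattened (4, 4) view at element 5; sym folds the index math down to constants
st = ShapeTracker.from_shape((4, 4))
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, 5)])
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
assert idx.op is Ops.CONST and valid.op is Ops.CONST
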
diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py
index 6601b5c9b7..931baf1405 100644
--- a/test/unit/test_simplify_valid_idx.py
+++ b/test/unit/test_simplify_valid_idx.py
@@ -3,27 +3,27 @@ from typing import Tuple
 from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
 from tinygrad.dtype import dtypes
-from tinygrad.ops import UOp, UOps, simplify_valid
+from tinygrad.ops import UOp, Ops, simplify_valid
 def get_gated_load_uop(valid:UOp, idx:UOp):
-  return UOp(UOps.LOAD, dtypes.float, (
-    UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
+  return UOp(Ops.LOAD, dtypes.float, (
+    UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
     idx,
     UOp.const(dtypes.float, 0.0),
     valid
   ))
 def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]):
-  return UOp(UOps.LOAD, dtypes.float.vec(4), (
-    UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
-    UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx),
-    UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
+  return UOp(Ops.LOAD, dtypes.float.vec(4), (
+    UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
+    UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx),
+    UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
     valid
   ))
-def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
+def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
 def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
-def Range(n, nmax): return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
+def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
 class TestHelpers(unittest.TestCase):
   def test_is_increasing(self):
@@ -43,7 +43,7 @@ class TestHelpers(unittest.TestCase):
     self.assertTrue(is_increasing(f2))
     self.assertTrue(is_increasing(f3))
-    rng = UOp(UOps.RANGE, dtypes.int, arg=(2, True), src=(UOp(UOps.CONST, dtypes.int, arg=0, src=()), UOp(UOps.CONST, dtypes.int, arg=5, src=()),))
+    rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=0, src=()), UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
     self.assertTrue(is_increasing(rng))
     self.assertTrue(is_increasing(rng+2))
@@ -87,7 +87,7 @@ class TestImageSimplification(unittest.TestCase):
   def check(self, load, svalid, sidx0, sidx1):
     load = full_graph_rewrite(load.sink()).src[0]
     idx = load.src[0].src[1]
-    self.assertEqual(idx.op, UOps.VECTORIZE)
+    self.assertEqual(idx.op, Ops.VECTORIZE)
     self.assertEqual(len(idx.src), 2)
     idx0, idx1 = idx.src[0], idx.src[1]
     self.assertEqual(idx0.render(simplify=False), sidx0)
@@ -152,7 +152,7 @@ class TestImageSimplification(unittest.TestCase):
     # empty -> invalid
     load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
     load = full_graph_rewrite(load.sink()).src[0]
-    self.assertEqual(load.op, UOps.VECTORIZE)
+    self.assertEqual(load.op, Ops.VECTORIZE)
     self.assertEqual(load.dtype.count, 4)
   def test_openpilot_conv1(self):
diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py
index b3d9a7648e..d50ee7805e 100644
--- a/test/unit/test_uop_symbolic.py
+++ b/test/unit/test_uop_symbolic.py
@@ -9,15 +9,15 @@ from typing import Tuple
 from tinygrad.dtype import dtypes, ConstType
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
-from tinygrad.ops import UOp, UOps, graph_rewrite, sym_infer
+from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad import Variable
 import functools
 def render(self) -> Tuple[str, ConstType, ConstType]:
   # NOTE: we need STORE so the ALU op has children
-  glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-  uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
-  rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
+  glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
+  uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
+  rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
   return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
 def NumNode(val): return UOp.const(dtypes.int, val)
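Symbolic expressions built from `UOp.variable` go through the same rewriter as everything else; the division-folding test earlier in this patch pins down exactly what comes out. A compact sketch of that case (illustrative, assuming the behavior those assertions state):

from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.codegen.uopgraph import full_graph_rewrite

# (n*6)//3 with n in [1, 1000] simplifies to n*2, an ALU MUL, not a constant
n = UOp.variable('n', 1, 1000)
simplified = full_graph_rewrite(((n * 6) // 3).sink()).src[0]
assert simplified.op is Ops.ALU and simplified.arg is BinaryOps.MUL
assert simplified.src[1].arg == 2
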
diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py
index bda822075a..ed6b6c67b3 100644
--- a/test/unit/test_uop_vmin_vmax.py
+++ b/test/unit/test_uop_vmin_vmax.py
@@ -1,5 +1,5 @@
 import unittest, math
-from tinygrad.ops import UOp, UOps
+from tinygrad.ops import UOp, Ops
 from tinygrad.dtype import dtypes
 class TestVminVmaxProperties(unittest.TestCase):
@@ -35,7 +35,7 @@ class TestVminVmaxProperties(unittest.TestCase):
   def test_vmin_vmax_multiplication_0_inf(self):
     # vmin and vmax for multiplication with a variable
     x = UOp.const(dtypes.float, 0.0)
-    y = UOp.load(UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
+    y = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
     uop = x * y
     # TODO: these should be 0, but definitely should not be nan
     self.assertEqual(uop.vmin, -math.inf)
diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py
index b758e31f96..27baa6a0ef 100644
--- a/test/unit/test_verify_ast.py
+++ b/test/unit/test_verify_ast.py
@@ -4,7 +4,7 @@ import unittest
 from tinygrad import Tensor
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.helpers import DEBUG
-from tinygrad.ops import UOp, UOps, ReduceOps, print_uops
+from tinygrad.ops import UOp, Ops, ReduceOps, print_uops
 from tinygrad.codegen.kernel import verify_ast
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad import dtypes
@@ -12,7 +12,7 @@ from tinygrad.shape.view import View
 class InvalidASTException(Exception): pass
 def helper_test_verify_ast(*stores:UOp) -> Kernel:
-  sink = UOp(UOps.SINK, dtypes.void, stores)
+  sink = UOp(Ops.SINK, dtypes.void, stores)
   if DEBUG >= 3:
     for op in stores: print(op)
   try: verify_ast(sink)
@@ -26,59 +26,59 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
 class TestVerifyAST(unittest.TestCase):
   def test_tiny_add(self):
     dtype = dtypes.int
-    buf_0 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 0)
-    buf_1 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 1)
-    buf_2 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 2)
-    a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
-    b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
-    store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
+    buf_0 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
+    buf_1 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
+    buf_2 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 2)
+    a = UOp(Ops.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
+    b = UOp(Ops.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
+    store = UOp(Ops.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
     helper_test_verify_ast(store)
   def test_exactly_one_full_shape(self):
     dtype = dtypes.int
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
-    a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
-    b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
+    a = UOp(Ops.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
+    b = UOp(Ops.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
     st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
-    a = UOp(UOps.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
-    b = UOp(UOps.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
+    a = UOp(Ops.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
+    b = UOp(Ops.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
     st1 = UOp.store(bufs[1], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
     with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
   def test_no_implicit_broadcasting(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
-    b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
-    st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
+    a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
+    b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
+    st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
     with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
   def test_shrink_ok(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
-    b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
+    a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
+    b = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
     helper_test_verify_ast(st)
   def test_reduce_store(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
-    r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
+    a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
     with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
   def test_reduce_add_store(self):
-    bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
-    a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
-    r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
+    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
+    a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
+    r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
     st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
     with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
   def test_buffer_uops_st(self):
     a = Tensor.randn(4, 4)+2
     uop_sts = verify_ast(a.schedule()[-1].ast)
-    store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
+    store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
     self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
-    const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
+    const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
     self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
 if __name__ == '__main__':
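The shape of a "well-formed kernel AST" that `verify_ast` accepts is spelled out by `test_tiny_add`: STOREs under a single SINK, each buffer access paired with a `ShapeTracker` uop. A minimal sketch assembled from exactly those pieces (illustrative, not part of the patch):

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.codegen.kernel import verify_ast

# the canonical well-formed AST: out[i] = a[i] + b[i] over a (32, 1) shape
buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
buf2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
st = ShapeTracker.from_shape((32, 1)).to_uop()
a = UOp(Ops.LOAD, dtypes.int, (buf1, st))
b = UOp(Ops.LOAD, dtypes.int, (buf2, st))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (buf0, st, a+b)),))
uop_sts = verify_ast(sink)  # returns a {UOp: ShapeTracker} map on success
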
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index ee0d57e3be..1cfb5a9345 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -5,7 +5,7 @@ from collections import defaultdict
 from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
 from enum import Enum, auto
-from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
+from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, \
   graph_rewrite, track_rewrites, Variable, sint
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
@@ -54,7 +54,7 @@ class TensorCoreOptions:
 class Kernel:
   def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
-    if ast.op is UOps.SINK: self.ast = ast
+    if ast.op is Ops.SINK: self.ast = ast
     self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
     try: uop_sts_map = verify_ast(self.ast)
@@ -65,7 +65,7 @@ class Kernel:
     @functools.lru_cache(None)
     def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
-    self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is UOps.REDUCE_AXIS])
+    self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
     self.vars: List[Variable] = self.ast.variables()
     self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
@@ -125,7 +125,7 @@ class Kernel:
     return ret
   @property
-  def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}])
+  def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
   # TODO: these need more tests or it might silently be no-op
   def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0]  # noqa: E501
@@ -273,16 +273,16 @@ class Kernel:
   def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
     has_cast = tc.dtype_in != tc.dtype_out
-    if has_cast and not (reduceop.src[0].op is UOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
+    if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
     mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
     if mul_op.arg is not BinaryOps.MUL: return None
     def buf_index(src:UOp) -> Optional[int]:
       # TODO: apply tc even if the sources are not from LOAD
-      if src.op is UOps.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
+      if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
       try:
-        if opt_level >= 1 and src.op is UOps.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
+        if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
       except ValueError: return None
       return None
     if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
@@ -442,7 +442,7 @@ class Kernel:
       check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ALU ops have f(0) = 0
       if (r:=self.reduceop) is not None and self.first_reduce <= axis:
-        check(r.arg[0] is BinaryOps.ADD and not any(u.op is UOps.ALU and u.arg in UNSAFE_PAD_OPS for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
+        check(r.arg[0] is BinaryOps.ADD and not any(u.op is Ops.ALU and u.arg in UNSAFE_PAD_OPS for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
       padded = False
       for i,st in enumerate(self.sts):
         if (s:=st.shape[axis]) == 1: continue  # reduced
@@ -472,7 +472,7 @@ class Kernel:
     MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
     if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
         self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
-        (mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is UOps.LOAD and mulop.src[1].op is UOps.LOAD:
+        (mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
       st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
       strides0, strides1 = st0.real_strides(), st1.real_strides()
       def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
@@ -615,20 +615,20 @@ class Kernel:
       arg = op.arg
       if op.op in BUFFER_UOPS:
         # for locals, we use the ShapeTracker that's in the srcs
-        st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
+        st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
         st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
-        if op.op is UOps.VALID: return op.replace(src=(st_uop,))
-        if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
+        if op.op is Ops.VALID: return op.replace(src=(st_uop,))
+        if op.op is Ops.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
         return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
-      if op.op is UOps.REDUCE_AXIS:
+      if op.op is Ops.REDUCE_AXIS:
        reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
        alu_op: BinaryOps = op.arg[0]
        axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
                     if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
        if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
          rsrc = op.src[0]
-          if rsrc.op is UOps.CAST: rsrc = rsrc.src[0]
-          assert rsrc.op is UOps.ALU and rsrc.arg is BinaryOps.MUL
+          if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
+          assert rsrc.op is Ops.ALU and rsrc.arg is BinaryOps.MUL
          def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
            wd, tcd = self.global_dims, self.first_upcast
@@ -654,30 +654,30 @@ class Kernel:
                              for i,s in enumerate(self.full_shape))
            srcs = []
            for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
-              st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
+              st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is Ops.LOAD]
              local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
              st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
-              membuf = UOp(UOps.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
-              local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
-              srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
+              membuf = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
+              local_store = fixup_ast(UOp(Ops.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
+              srcs.append(UOp(Ops.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
          else:
            # for TC=2, we can't do the shapetracker fixup
            srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
            # MUL/SUM instead of WMMA
-            ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
+            ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
          else:
            # real WMMA, use CONTRACT/EXPAND to get the vectorization right
            wmma_upcast_axes = wmma_arg[-2]
            wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
-            wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
-              UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
-              UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
+            wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
+              UOp(Ops.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
+              UOp(Ops.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
              UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
-            ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
+            ret = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
          new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in [ax for ax, _ in tc.reduce_axes])
          return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
        if self.group_for_reduces:
-          start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
+          start = UOp(Ops.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
          second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
                          if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
          # NOTE: if there's a grouped reduce, but no reduce axes for this reduce, we can skip it
@@ -687,14 +687,14 @@ class Kernel:
                                 for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
            (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
-          local_buffer = UOp(UOps.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
-          local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
-          grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
+          local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+          local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
+          grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
          if op is self.reduceops[-1]: return grouped_reduce
          st_uop = ShapeTracker.from_shape(tuple([1 if i in second_axis else a for i,a in enumerate(local_shape)])).to_uop()
-          return UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
+          return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
        arg = (alu_op, axis)
-      elif op.op is UOps.SINK:
+      elif op.op is Ops.SINK:
        arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
      return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
    # NOTE: rewrite with an empty PatternMatcher to dedup UOps
@@ -728,7 +728,7 @@ class Kernel:
     # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
     # TODO: these max and min don't work on symbolic, and results are very wrong.
     mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
-      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is UOps.DEFINE_GLOBAL],
+      for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
        key=lambda x: (x.op, x.src[0].arg)))
     return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                    global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
@@ -738,15 +738,15 @@ class Kernel:
 def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
   if not uop.has_st or uop in sts: return
   # restore globals from the two stage reduce
-  if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
+  if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
     _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
     sts[uop] = sts[local_reduce]
     return
   for x in uop.src: _assert_valid_uop(x, st, sts)
   # only reduceuop is allowed to change shape, limited to turning n to 1
-  if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
+  if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
   # movementops are pushed to VIEW
-  elif uop.op is UOps.VIEW: st = uop.arg
+  elif uop.op is Ops.VIEW: st = uop.arg
   # everything else inherits shape
   else:
     st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
@@ -756,7 +756,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
   sts[uop] = st
 def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
-  assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
+  assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
   assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
   sts: Dict[UOp, ShapeTracker] = {}
   for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
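`Kernel.__init__` above is the entry point the earlier test changes exercise: it accepts any `Ops.SINK` AST, verifies it, and collects reduceops, vars, and bufs. A short sketch of the usage pattern those tests imply (illustrative only):

from tinygrad import Tensor
from tinygrad.ops import Ops
from tinygrad.codegen.kernel import Kernel

# any schedule item whose AST is an Ops.SINK can be wrapped in a Kernel for codegen
si = (Tensor.randn(4, 4) + 2).schedule()[-1]
assert si.ast.op is Ops.SINK
k = Kernel(si.ast)  # runs verify_ast, then collects reduceops, vars, and bufs
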
diff --git a/tinygrad/codegen/linearize.py b/tinygrad/codegen/linearize.py
index b7c4d26ea5..d213e22ff7 100644
--- a/tinygrad/codegen/linearize.py
+++ b/tinygrad/codegen/linearize.py
@@ -1,6 +1,6 @@
 from typing import List, Set, Dict, Tuple
 import functools, heapq
-from tinygrad.ops import type_verify, END_FOR_UOP, UOp, UOps
+from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import DEBUG
@@ -10,13 +10,13 @@ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[U
   children[u] = []
   for x in u.src:
     srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
-    if x.op is UOps.RANGE and x.arg[1]: srcs[u][x] = None
+    if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
     children[x].append(u)
   in_degree[u] = len(u.src)
   return srcs[u]
 def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
-  assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
+  assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
   # filter nodes that don't link to a sink
   # BFS toposort
   children: Dict[UOp, List[UOp]] = {}
@@ -25,38 +25,38 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
   get_children_dfs(sink, children, range_srcs, in_degree)
   @functools.lru_cache(None)
-  def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
-    if x.op is UOps.SINK: return set()
+  def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
+    if x.op is Ops.SINK: return set()
     return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
   # scope children impact the toposort and END* insertion
   scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
-  range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE}
+  range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
   # assign priorities
   def get_priority(u:UOp):
     priority = 0
     # prefer ranges that depend on the least number of independent ranges
-    if u.op is UOps.RANGE and u.arg[1]:
+    if u.op is Ops.RANGE and u.arg[1]:
       priority += u.arg[0]
       for p in range_phi[u]:
         priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
-    elif u.op is UOps.CONST:
+    elif u.op is Ops.CONST:
       # place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
       priority -= 100000000000
     else:
       # prefer uops that are loop children
-      priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
-    if u.op is UOps.IF and len(u.src) == 1: priority += 10000000 # if penalty
+      priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
+    if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
     return priority
   priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
   # prevent priority inversion
   @functools.lru_cache(None)
   def fix_priority(u:UOp, lowest_priority):
-    if u.op in {UOps.CAST, UOps.BITCAST, UOps.ALU, UOps.VECTORIZE, UOps.GEP, UOps.SPECIAL, UOps.DEFINE_LOCAL, UOps.LOAD}:
+    if u.op in {Ops.CAST, Ops.BITCAST, Ops.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
       priorities[u] = min(priorities[u], lowest_priority)
-      if u.op is UOps.LOAD: priorities[u] += 100 # load penalty (here)
+      if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
     for x in u.src: fix_priority(x, priorities[u])
   fix_priority(sink, 0)
@@ -73,8 +73,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
     p,_,x = heapq.heappop(queue)
     if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
     if x in scope_children: scope_end[x] = x
-    if x.op is UOps.DEFINE_ACC:
-      idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
+    if x.op is Ops.DEFINE_ACC:
+      idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
       _uops.insert(idx, x)
     else: _uops.append(x)
   for u, ss in scope_children.items():
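`linearize_uop` is the other half of the pipeline the symbolic tests touched earlier: a rewritten SINK graph goes in, a topologically sorted flat program comes out. A minimal sketch stitched from the `render` helper in test_uop_symbolic.py (illustrative, not part of the patch):

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite

# build a STORE (so the value has a consumer), sink it, rewrite, then linearize
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
store = UOp(Ops.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 42)))
uops = linearize_uop(full_graph_rewrite(store.sink()))
assert any(u.op is Ops.STORE for u in uops)  # the STORE survives in the linear program
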
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 995077107f..1c2562f189 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -6,7 +6,7 @@
 from typing import List, Tuple, cast, Optional
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import variable_to_uop
 from tinygrad.dtype import dtypes
-from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint, identity_element
+from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, prod, partition, flatten
@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
 def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
     # cast for mypy, get_contraction won't be None
@@ -57,8 +57,8 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
   first_upcasted = len(full_shape)-ki.upcasted
   first_output_st: ShapeTracker = ast.src[0].st_arg
   # if there's no reduce, this is first_upcasted. assumes reduces are at the end
-  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is UOps.REDUCE_AXIS))
-  local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
+  first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
+  local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
   # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
   group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
     [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
       get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
+    idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
             for i,g in enumerate(full_shape[:first_reduce])]
   # reduce loops
-  idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
   return IndexContext(idxs, ridxs)
@@ -98,50 +98,50 @@ def lower_reduce_axis(ctx: IndexContext, x: UOp):
   # NOTE: always using ridxs is fine here
-  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is UOps.RANGE)
-  assert all(x.op is UOps.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
+  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
+  assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
   alu_op: BinaryOps = x.arg[0]
   ret = x.src[0]
   if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
-    ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
+    ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
     ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
-  return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
+  return UOp(Ops.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
 def lower_load_store(ctx: IndexContext, x: UOp):
-  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
+  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
   # TODO: check has_valid in UPat, not here
-  has_valid = valid.op is not UOps.CONST or valid.arg is not True
+  has_valid = valid.op is not Ops.CONST or valid.arg is not True
   buf = x.src[0]
-  if x.op is UOps.LOAD:
-    barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
-    return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
+  if x.op is Ops.LOAD:
+    barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
+    return UOp(Ops.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
   # NOTE: only store the local reduceop in the threads that are actually doing the reduce
-  store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
-    x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
+  store_back = x.src[0].op is Ops.DEFINE_LOCAL and x.src[2].op is Ops.REDUCE and \
+    x.src[2].src[0].op is Ops.LOAD and x.src[2].src[0].src[0].op is Ops.DEFINE_LOCAL
   # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
   if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
-  if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
+  if x.src[0].op is Ops.DEFINE_GLOBAL or store_back:
     for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
       if oidx is not ridx: valid = valid * oidx.eq(0)
-    has_valid = valid.op is not UOps.CONST or valid.arg is not True
-  return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
+    has_valid = valid.op is not Ops.CONST or valid.arg is not True
+  return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
 pm_lowerer = PatternMatcher([
-  (UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
-  (UPat(UOps.VALID, src=(UPat(UOps.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
+  (UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
-  (UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
 ])
 def do_reduce(ctx:List[int], root:UOp):
-  acc = UOp(UOps.DEFINE_ACC, root.dtype,
+  acc = UOp(Ops.DEFINE_ACC, root.dtype,
             (root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
   ctx[0] += 1
   return acc.assign(acc.alu(root.arg, root.src[0]))
 just_reduce = PatternMatcher([
   # do reduce
-  (UPat(UOps.REDUCE, name="root"), do_reduce),
+  (UPat(Ops.REDUCE, name="root"), do_reduce),
 ])
 def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
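The lowerer's public entry point, `rewrite_shapetracker_with_index`, drives `pm_lowerer` over a verified kernel AST to replace VIEW-carrying LOAD/STOREs with concrete index/valid expressions. A hedged usage sketch (the `sink` here is assumed to be a verified AST like the one built in the verify_ast sketch earlier; calling convention taken from the signature above):

from tinygrad import Device
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index

# `sink` is any verified Ops.SINK kernel AST; the renderer supplies global/local size limits
lowered = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
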
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index afd899794f..6990dc2454 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
 import functools, itertools, operator
 from collections import defaultdict
 from tinygrad.dtype import dtypes, ImageDType, PtrDType
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
 from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
 from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
 from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -16,15 +16,15 @@ def fold_expanded(ex, buf):
   if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
   new_srcs = dedup(list(ex.src))
   old_new_srcs = new_srcs[:]
-  is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
+  is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
   # first, extract all the relevant offsets
   offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
   for i,s in enumerate(new_srcs):
     idx = s.src[0].src[1]
     if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
-    if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
-    elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
+    if idx.arg is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
+    elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
     else: root_src, arg = idx, 0
     # add gates for gated
     if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
@@ -45,20 +45,20 @@ def fold_expanded(ex, buf):
       if is_image:
         # for images, we rewrite the index. it must evenly divide 4 from the above check
         new_src[0] = buf.index(
-          UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+          UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
           rootsrc[0] if isinstance(rootsrc, tuple) else None)
       else:
         # for non image, we upcast the index pointer
         new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
       # vectorize the store
       if not is_load:
-        new_src[1] = UOp(UOps.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
+        new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
       # generate the folded new_srcs
       if is_load:
-        new_load = UOp(UOps.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
+        new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
         for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
       else:
-        for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
+        for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
       for i in range(fold_length): used.add((rootsrc,o+i))
   # dedup expand for LOAD
@@ -72,15 +72,15 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
   new_src = list(load.src)
   # TODO: copied logic from above
   new_src[0] = load.src[0].src[0].index(
-    UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
+    UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
     load.src[0].src[2] if len(load.src[0].src) == 3 else None)
-  vec_load = UOp(UOps.LOAD, load.dtype.vec(4), tuple(new_src))
+  vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
   return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
-buf_idx_pat = UPat(UOps.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
+buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
 float4_folding = PatternMatcher([
-  (UPat(UOps.VECTORIZE, src=UPat(UOps.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
-  (UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
+  (UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
 ])
 # ***** image load valid simplification *****
@@ -124,18 +124,18 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
 powers_of_two = {2**i:i for i in range(64)}
 @functools.lru_cache(None)
 def get_late_rewrite_patterns(ops, force_transcendental=False):
-  pat: List[Tuple[UPat, Callable]] = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
+  pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
     ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
   # rewrite MOD to AND (which should always be supported, but not for generic in tests)
   if BinaryOps.AND in ops:
-    pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
+    pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
       lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
   # rewrite MUL/IDIV to SHL+SHR
   if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
     pat += [
-      (UPat(UOps.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
+      (UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
        mul << powers_of_two[const.arg] if const.arg in powers_of_two else None),  # (x * (2**y)) -> shl(x,y)
-      (UPat(UOps.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
+      (UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
        div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)]  # (x // (2**y)) -> shr(x,y)
   if UnaryOps.NEG in ops:
     pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
@@ -175,8 +175,8 @@ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extr
   if vec is not None:
     # add, mul, loop_start, loop_end
     def dvec(x:UOp):
-      if x.op is UOps.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
-      return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
+      if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
+      return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
     add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
   if mul.vmin > 0 and ne is not None:
     comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
@@ -209,7 +209,7 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
     ssz = prod(x[1] for x in sz)
     for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
     tsrcs.append(s.gep(tuple(src_args)))
-  return UOp(UOps.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+  return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
 def no_vectorized_wmma(wmma:UOp):
   out_sz = prod(x[1] for x in wmma.arg[6][-1])
@@ -218,9 +218,9 @@ def no_vectorized_wmma(wmma:UOp):
   for s,sz in zip(wmma.src, wmma.arg[6]):
     ssz = prod(x[1] for x in sz)
     tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
-  wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
+  wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
   wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
-  return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
+  return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
 def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
   reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
@@ -231,81 +231,81 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
   for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
   return ret
-acc_pat, rng_pat = UPat(UOps.DEFINE_ACC, name="acc"), UPat(UOps.RANGE, name="rng")
+acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
 rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
 index_load = UPat.var("buf").index(rng_aug).load(name="ld")
-arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=rng_aug))
-arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(UOps.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
+arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
+arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
 # this is symbolic 2.0
 sym = symbolic_flat+PatternMatcher([
   # self ASSIGN is just self
-  (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
+  (UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
   # ASSIGN to global is just self
-  (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
+  (UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
   # VECTORIZE/CONST, VECTORIZE/GEP
-  (UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
-  (UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
+  (UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
   # reorder ALU/VECTORIZE
-  (UPat(UOps.ALU, src=(UPat(UOps.VECTORIZE, src=UPat(name='x')), UPat(UOps.VECTORIZE, src=UPat(name='y'))), name='alu'),
-   lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
+  (UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
+   lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
   # VECTORIZE of a single element is just that element
-  (UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
+  (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
   # VECTORIZE void is SINK
-  (UPat(UOps.VECTORIZE, dtype=dtypes.void, src=UPat(UOps.BARRIER, name='b')), lambda b: b),
-  (UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
+  (UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
+  (UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
   # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
-  (UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
+  (UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
    lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
-  (UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="vec"),), name="gep"),
-   lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
-  (UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
-  (UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
+  (UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
+   lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
-  (UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
+  (UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
+  (UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
   # push all GEPs through ALUs (fix arange stuff)
-  (UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
+  (UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
   # push some GEPs through WMMAs
-  (UPat(UOps.GEP, src=(UPat(UOps.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
+  (UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
   # tensor core with a 0 input is acc
-  (UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
-  (UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
+  (UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
+  (UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
   # tensor core cleanups
-  (UPat.var("add") + UPat(UOps.WMMA, name="wmma"),
+  (UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
    lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
   # threefry
-  (UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
+  (UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
   # arange loop folding
   (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
   # indexing, with cast or where
-  (acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
-  (acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
+  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
+  (acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
   # parentless reduce
-  (acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
-  (acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
+  (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
+  (acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
   # ** self folding **
-  (UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x),    # a DEFINE_ACC without ranges is a CONST
-  (UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x),  # an ASSIGN to a const is a NOOP
+  (UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x),     # a DEFINE_ACC without ranges is a CONST
+  (UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x),   # an ASSIGN to a const is a NOOP
   # x!=0 -> (bool)x
   (UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
   # ** load/store folding **
-  (UPat.store(UPat(UOps.INDEX, name="index"), UPat.load(UPat(UOps.INDEX, name="index"))), lambda index: UOp(UOps.NOOP)),
-  (UPat.store(UPat(UOps.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(UOps.INDEX, name="index")))),
+  (UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
+  (UPat.store(UPat(Ops.INDEX, name="index"), 
UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)), + (UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))), lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)), # fold gated LOAD/STORE (UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True (UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer - (UPat(UOps.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0 - (UPat(UOps.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(UOps.NOOP)), # NULL pointer store does nothing + (UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0 + (UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing # remove NOOPs from SINK - (UPat(UOps.SINK, name="root"), - lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None), + (UPat(Ops.SINK, name="root"), + lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None), # remove EXPANDs from SINK/BARRIER - (UPat(UOps.BARRIER, src=(UPat((UOps.VECTORIZE, UOps.SINK), name='sink'),)), lambda sink: UOp(UOps.BARRIER, dtypes.void, sink.src)), - (UPat(UOps.SINK, name="root"), - lambda root: UOp(UOps.SINK, root.dtype, tuple(flatten(x.src if x.op in {UOps.SINK, UOps.EXPAND} else (x,) for x in root.src)), root.arg) - if any(x.op in {UOps.SINK, UOps.EXPAND} for x in root.src) else None), + (UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)), + (UPat(Ops.SINK, name="root"), + lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg) + if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None), ]) # *** uop expander *** @@ -325,10 +325,10 @@ def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int] return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)] def do_expand(root:UOp): - expands = [x for x in root.src if x.op is UOps.EXPAND] + expands = [x for x in root.src if x.op is Ops.EXPAND] if len(expands) == 0: return None # NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct? 
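# --- editor's aside (not part of the diff) ----------------------------------
# The tables above are (UPat pattern, callback) pairs consumed by
# PatternMatcher/graph_rewrite. A minimal, hedged sketch of that mechanism
# under the new Ops naming; it only uses names visible in this diff, and the
# `demo`/`one_wide` identifiers are illustrative, not part of the codebase:
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp, UPat, PatternMatcher, graph_rewrite

# mirror of the "VECTORIZE of a single element is just that element" rule
demo = PatternMatcher([(UPat(Ops.VECTORIZE, src=(UPat(name="x"),)), lambda x: x)])
one_wide = UOp(Ops.VECTORIZE, dtypes.int, (UOp.const(dtypes.int, 3),))
rewritten = graph_rewrite(one_wide.sink(), demo)
assert rewritten.src[0] is UOp.const(dtypes.int, 3)  # UOps are hash-consed, so `is` holds
# -----------------------------------------------------------------------------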
- exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else () + exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else () if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0: # if there's only one expand arg, it's okay to use it (optimization) expand_args = expands[0].arg @@ -338,8 +338,8 @@ def do_expand(root:UOp): expand_sz = prod([x[1] for x in expand_args]) new_srcs = [] for i,src in enumerate(root.src): - if src.op is UOps.EXPAND: - if root.op is UOps.IF and i == 0: + if src.op is Ops.EXPAND: + if root.op is Ops.IF and i == 0: # IF means OR on first arg to IF new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)])) elif expand_args == src.arg: @@ -352,70 +352,70 @@ def do_expand(root:UOp): new_srcs.append(src.src[0].gep(tuple(lst))) else: # non-EXPAND input - if (root.op is UOps.IF) or (root.op is UOps.REDUCE and i != 0): + if (root.op is Ops.IF) or (root.op is Ops.REDUCE and i != 0): # for the first arg of IF and the RANGE args of REDUCE, just pass them through ignoring EXPANDS new_srcs.append(src) elif src.dtype.count > 1: # put any input dtype > 1 grouped together - new_srcs.append(UOp(UOps.VECTORIZE, + new_srcs.append(UOp(Ops.VECTORIZE, src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz)) else: # repeat the arg new_srcs.append(src.broadcast(expand_sz)) new_arg = root.arg - if root.op is UOps.GEP: + if root.op is Ops.GEP: assert root.dtype.count == 1 # is this right? new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz)) nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg) - return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args) + return UOp(Ops.EXPAND, root.dtype, (nsrc,), expand_args) def do_contract(con:UOp): ex = con.src[0] # CONTRACT without EXPAND repeats the element VECTORIZED - if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count) + if ex.op is not Ops.EXPAND: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count) # CONTRACT may remove several axes from EXPAND assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong" idxs = [] for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)): idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)] - return UOp(UOps.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args) + return UOp(Ops.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args) def no_vectorized_alu(alu): if alu.dtype.vcount == 1: return None alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount)) - return UOp(UOps.VECTORIZE, alu.dtype, alus) + return UOp(Ops.VECTORIZE, alu.dtype, alus) def create_gate(root:UOp) -> Optional[UOp]: @functools.lru_cache(None) def _gate_srcs(u:UOp, gate:UOp) -> UOp: - if u.op is UOps.BARRIER: return u - if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER: - return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, dtypes.void, (gate, u.src[-1])),), u.arg) + if u.op is Ops.BARRIER: return u + if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER: + return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, dtypes.void, (gate, u.src[-1])),), u.arg) return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in 
u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg) idx = root.src[0] - if idx.op is UOps.CAST: idx = idx.src[0] - return None if idx.op is not UOps.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret + if idx.op is Ops.CAST: idx = idx.src[0] + return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret expander = PatternMatcher([ # double expand - (UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)), - lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), + (UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)), + lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)), # do expansion - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, UOps.ASSIGN, - UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand), - (UPat(UOps.CONTRACT, name="con"), do_contract), + (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN, + Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand), + (UPat(Ops.CONTRACT, name="con"), do_contract), # vectorize DEFINE_ACC - (UPat(UOps.VECTORIZE, src=UPat(UOps.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)), + (UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)), # BARRIERs aren't actually expanded - (UPat(UOps.BARRIER, src=(UPat(UOps.EXPAND, name="ex"),)), - lambda ex: UOp(UOps.EXPAND, dtypes.void, (UOp(UOps.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)), + (UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)), + lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)), # empty EXPAND is NOOP - (UPat(UOps.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x), + (UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x), # EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU - (UPat(UOps.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))), - lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)), + (UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))), + lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)), ]) def no_vectorized_load_store(ls:UOp): @@ -423,75 +423,75 @@ def no_vectorized_load_store(ls:UOp): assert isinstance(idx.dtype, PtrDType) if idx.dtype.v == 1: return None tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)] - return UOp(UOps.VECTORIZE, ls.dtype, tuple(tv)) + return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv)) def no_vectorized_acc(acc:UOp): if acc.dtype.count == 1: return None alus = tuple(UOp(acc.op, acc.dtype.scalar(), tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count)) - return UOp(UOps.VECTORIZE, acc.dtype, alus) + return UOp(Ops.VECTORIZE, acc.dtype, alus) devectorize = PatternMatcher([ # no ALU on vectorized dtypes - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.INDEX), name="alu"), no_vectorized_alu), - (UPat(UOps.WMMA, name="wmma"), no_vectorized_wmma), - 
(UPat(UOps.DEFINE_ACC, name="acc"), no_vectorized_acc), - (UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store), + (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu), + (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma), + (UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc), + (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store), ]) def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]: - if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is UOps.IF]: return None + if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None # remove the gate from the index return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val) load_store_indexing = PatternMatcher([ # late fixup of unfoldable image loads - (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), + (UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load), # simplify valid - (UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), + (UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid), # image load valid idx simplification - (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load), # delete_redundant_gates (after expand) - (UPat(UOps.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), + (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")), UPat.var("val"))), delete_redundant_gates), ]) def idx_load_store(x:UOp): idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None) - v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count + v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local)) - post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is UOps.LOAD else x.src[3:]) - if x.op is UOps.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg) + post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:]) + if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg) return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg) migrate_indexing = PatternMatcher([ # use indexing for LOAD/STORE - (UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), + (UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store), # create gate MUST BE BEFORE expander - (UPat(UOps.STORE, name="root"), create_gate), + (UPat(Ops.STORE, name="root"), create_gate), ]) def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp: # this moves the mask from the indexing to the load/store op for rendering nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx) - return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:]) + return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], 
dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:]) pm_render = PatternMatcher([ # for rendering, we use explicit VECTORIZE - (UPat(UOps.CONST, name='c'), - lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None), - (UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), - (UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), - (UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), + (UPat(Ops.CONST, name='c'), + lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None), + (UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))), + (UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None), + (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), # move masks of loads/stores - (UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))), + (UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))), masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask), ]) # *** uop graph *** def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: - assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}" + assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}" supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else () extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([]) diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index a1c0106a57..3e1994b892 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -3,7 +3,7 @@ from collections import defaultdict from tinygrad.engine.schedule import ScheduleItem from tinygrad.device import Device, Buffer from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG -from tinygrad.ops import UOps +from tinygrad.ops import Ops # **************** memory planning **************** @@ -47,5 +47,5 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...] def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. 
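# --- editor's aside (not part of the diff) ----------------------------------
# After the rename, "is this schedule item a real compute kernel?" is spelled
# `si.ast.op is Ops.SINK`; non-SINK items are metaops (COPY/EMPTY/BUFFER_VIEW).
# A hedged sketch of the classification this hunk relies on, assuming only the
# ScheduleItem fields (.ast, .bufs) used here; `metaop_buffers` is illustrative:
from tinygrad.ops import Ops

def metaop_buffers(schedule):
  # buffers touched by non-SINK items (e.g. transfers) keep their own
  # allocations, since reusing them would serialize otherwise-parallel graphs
  return {b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}
# -----------------------------------------------------------------------------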
assigned = _internal_memory_planner([si.bufs for si in schedule], - noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs}) + noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}) return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule] diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index ea6adc2794..8451da1ceb 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast, Generator, Tuple import time, pprint from dataclasses import dataclass, replace from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA -from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint +from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint from tinygrad.dtype import dtypes from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, Program @@ -181,18 +181,18 @@ class ExecItem: return et def lower_schedule_item(si:ScheduleItem) -> ExecItem: - assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is UOps.COPY - if si.ast.op is UOps.SINK: + assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY + if si.ast.op is Ops.SINK: runner = get_runner(si.outputs[0].device, si.ast) return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata) out, arg = si.outputs[0], si.ast.arg - if si.ast.op is UOps.COPY: + if si.ast.op is Ops.COPY: kernel_type = BufferCopy if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: kernel_type = BufferXfer return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs)) - if si.ast.op is UOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs)) - if si.ast.op is UOps.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs)) + if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs)) + if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs)) raise RuntimeError(f"don't know how to lower {si.ast}") def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]: diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 72b757816f..70b3dcc700 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -2,7 +2,7 @@ import sys, atexit, functools, itertools from collections import defaultdict, deque from dataclasses import dataclass, field from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast -from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint +from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap from tinygrad.dtype import ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -15,7 +15,7 @@ from tinygrad.device import Buffer sys.setrecursionlimit(10000) BUF_LIMIT = {"METAL":32} -METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW} +METAOPS = {MetaOps.COPY:Ops.COPY, MetaOps.EMPTY:Ops.EMPTY, MetaOps.VIEW:Ops.BUFFER_VIEW} # **** ScheduleItem return type @@ -34,7 +34,7 @@ class 
ScheduleItem: """Read only buffers in the schedule.""" return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) @functools.cached_property - def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,) + def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) # **** small wrapper for LazyBuffer -> UOp @@ -62,21 +62,21 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> # consts are always fused and generated if buf.op is MetaOps.CONST: if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()]) - return UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0)) + return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0)) # everything else has BUFFER - ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))) + ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))) # if the buffer is already realized we just load it - if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop())) + if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop())) # everything else needs sources src = tuple(to_uop(x, ctx, cache) for x in buf.srcs) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) - elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src) - elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg) + elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src) + elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg) elif buf.op in METAOPS: ret = UOp(METAOPS[cast(MetaOps, buf.op)], buf.dtype, (ubuf, *src), buf.arg) - elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src) - elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src) - else: ret = UOp(UOps.ALU, dtype, src, buf.op) - cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) + elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src) + elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src) + else: ret = UOp(Ops.ALU, dtype, src, buf.op) + cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata if buf.forced_realize: ctx.realizes[ubuf] = ubuf return ret @@ -87,7 +87,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) -> def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp: if (n:=cache.get(u)) is not None: return n - if u.op is UOps.VIEW: return u.replace(arg=apply_to_st(u.arg)) + if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg)) if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src)) return ret @@ -123,7 +123,7 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp: return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape)) def push_swizzle_down_through_elementwise(root:UOp) -> 
Optional[UOp]: - swizzles = [x for x in root.src if x.op is UOps.VIEW and len(x.src) != 0] + swizzles = [x for x in root.src if x.op is Ops.VIEW and len(x.src) != 0] if len(swizzles) == 0: return None swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles] assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}" @@ -131,34 +131,34 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: fixup_cache: Dict[UOp, UOp] = {} new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src] ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg) - return ret if ret.op is UOps.STORE else ret.view(ShapeTracker.from_shape(new_shape)) + return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape)) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" - assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time" + assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time" return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg) -merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))]) +merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))]) # push VIEW to loads view_left = merge_views+PatternMatcher([ # view before ALU - (UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"), + (UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))), ]) # push VIEW to stores view_right = merge_views+PatternMatcher([ # ASSIGN can override st - (UPat(UOps.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(UOps.ASSIGN, name="a"))), + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))), lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None), # VIEW on a reduce creates a new VIEW - (UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r), + (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r), # push a VIEW down to STORE, through a reduce (ONLY reshapes) - (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes) - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), - (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), + (UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) # ** 
ScheduleItem context builder @@ -180,18 +180,18 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]: def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: ctx.bufs.append(x) - return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1) -append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf)]) + return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1) +append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)]) def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: if b in ctx.assigned: ctx.assign_preloads.append(b) - return x.replace(op=UOps.LOAD) + return x.replace(op=Ops.LOAD) to_si = PatternMatcher([ - (UPat(UOps.VIEW, name="x"), _append_st_vars), - (UPat(UOps.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload), - (UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), - (UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x), + (UPat(Ops.VIEW, name="x"), _append_st_vars), + (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload), + (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), + (UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x), ]) # ** fusion @@ -206,15 +206,15 @@ def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) - # fuse and fold store -> loads sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.src[0]:x.src[2] for x in pre.src}) # assert cyclic dependency - for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD,UOps.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]): + for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]): if not all_same([x.op for x in ops]): raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.") # do movementops sink = graph_rewrite(graph_rewrite(sink, view_left), view_right) # we also allow masked views. 
if it has a single view and it's equal when you shrink a contig, it's fine - if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0: + if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is Ops.ASSIGN]) != 0: if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \ - and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.PRELOAD and x.src[0] in assign_targets): + and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.src[0] in assign_targets): raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) # convert to AST @@ -232,13 +232,13 @@ if getenv("RUN_PROCESS_REPLAY"): def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp: ctx[b] = store - return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop())) + return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop())) def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load") do_realize = PatternMatcher([ # always realize meta ops - (UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize), - (UPat((UOps.COPY, UOps.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"), + (UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize), + (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"), lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),) ]) break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),]) diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index b9ccd5185e..f2f5d4886c 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable import itertools, functools, random, math, time, multiprocessing, traceback, signal from collections import defaultdict from dataclasses import replace -from tinygrad.ops import UOp, UOps, Variable, sym_infer +from tinygrad.ops import UOp, Ops, Variable, sym_infer from tinygrad.device import Device, Buffer, Compiler from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name from tinygrad.dtype import ImageDType, PtrDType @@ -88,7 +88,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]: bufsts: DefaultDict[int, List[UOp]] = defaultdict(list) for x in lin.bufs: - if x.src[0].op is UOps.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x) + if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x) rawbufs: List[Optional[Buffer]] = [None]*len(bufsts) for k,lx in bufsts.items(): buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index a663e9d8c5..ce7b7e7ad1 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -123,8 +123,8 @@ REDUCE_ALU: Dict[ReduceOps, BinaryOps] = 
{ReduceOps.SUM:BinaryOps.ADD, ReduceOps # https://en.wikipedia.org/wiki/Identity_element def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt) -# the order of these UOps controls the order of the toposort -class UOps(FastEnum): +# the order of these Ops controls the order of the toposort +class Ops(FastEnum): # uops that aren't rendered SINK = auto() CONTIGUOUS = auto() @@ -183,9 +183,9 @@ class UOps(FastEnum): VCONST = auto() CONST = auto() -BUFFER_UOPS = {UOps.LOAD, UOps.PRELOAD, UOps.STORE, UOps.VALID} +BUFFER_UOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID} COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR} -END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} +END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)} # With True as the default, this matches the old symbolic behavior def resolve(x, default:bool=True): @@ -218,14 +218,14 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s class UOpMetaClass(type): ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary() - def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): + def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None): if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg) return ret class UOp(MathTrait, metaclass=UOpMetaClass): __slots__ = ["op", "dtype", "src", "arg"] - def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): + def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None): # TODO: instant check rules here make debugging easier #assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}" #if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool @@ -243,7 +243,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def key(self) -> bytes: return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest() def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))") - def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg + def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg @functools.cached_property def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}} @functools.cached_property # parents with self @@ -251,31 +251,31 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]: - return (self.op.value, self.arg.value if self.op is UOps.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src)) + return (self.op.value, self.arg.value if self.op is Ops.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src)) # *** uop shape stuff *** @property - def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.BUFFER, UOps.CONST, UOps.DEFINE_VAR} + def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, 
Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR} @functools.cached_property def st(self) -> Optional[ShapeTracker]: if not self.has_st: return None if self.op in BUFFER_UOPS: return self.st_arg - if self.op is UOps.VIEW: return self.arg + if self.op is Ops.VIEW: return self.arg src_sts = [x.st for x in self.src if x.st is not None] assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}" from tinygrad.shape.shapetracker import ShapeTracker - return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0] + return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0] @functools.cached_property def full_shape(self) -> Tuple[sint, ...]: - return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) + return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) # *** uop evaluation *** def simplify(self): with Context(TRACK_MATCH_STATS=0): return graph_rewrite(self, symbolic) - def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret + def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret def _eval(self, dtype, expected_type:Type[T]) -> T: assert self.dtype in dtype, f"eval with wrong dtype {self}" vmin, vmax = (simple_self:=self.simplify())._min_max @@ -294,98 +294,98 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @property def st_arg(self) -> ShapeTracker: assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}" - ret = self.src[0 if self.op is UOps.VALID else 1] - assert ret.op is UOps.VIEW, f"st_arg trying to return {ret}" + ret = self.src[0 if self.op is Ops.VALID else 1] + assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}" return ret.arg @property def axis_arg(self) -> Tuple[int, ...]: - assert self.op in {UOps.REDUCE_AXIS, UOps.WMMA}, f"axis_arg called on {self.op}" - ret = self.arg[1] if self.op is UOps.REDUCE_AXIS else self.arg[7] + assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}" + ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7] assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}" return ret - def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) - def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) - def view(self, st:ShapeTracker): return UOp(UOps.VIEW, self.dtype, (self,), st) + def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs) + def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) + def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st) def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self - return UOp(UOps.VECTORIZE, self.dtype.vec(count), (self,)*count) - def cast(self, dtype:DType): return UOp(UOps.CAST, dtype, (self,)) - def bitcast(self, dtype:DType): return UOp(UOps.BITCAST, dtype, (self,)) + return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count) + def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,)) + def 
bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,)) def gep(self, i:Union[Tuple[int, ...], int]): if isinstance(i, int): # NOTE: these are just shortcuts to not have to create and fold later - if self.op is UOps.VECTORIZE: return self.src[i] - if self.op is UOps.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i]) - if self.op is UOps.CONST: return UOp.const(self.dtype.scalar(), self.arg) + if self.op is Ops.VECTORIZE: return self.src[i] + if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i]) + if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg) i = (i,) if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}" - return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) - def load(self, *src:UOp, **kwargs): return UOp(UOps.LOAD, src=(self,)+src, **kwargs) - def store(self, *src:UOp, **kwargs): return UOp(UOps.STORE, dtypes.void, (self,)+src, **kwargs) + return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i) + def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs) + def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool - return UOp(UOps.ALU, out_dtype, (self,)+src, arg) + return UOp(Ops.ALU, out_dtype, (self,)+src, arg) @staticmethod def const(dtype:DType, b:ConstLike): - if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b + if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same - return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore + return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore @staticmethod def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int): - return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start, + return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start, UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx) - def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op) - def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis)) - def assign(self, x:UOp): return UOp(UOps.ASSIGN, self.dtype, (self,x)) + def reduce(self, op:BinaryOps, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op) + def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis)) + def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) # *** uop Variable stuff *** @staticmethod def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int): assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}" - 
return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) + return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @property def expr(self): - assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" + assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" return self.arg[0] def bind(self, val:int): - assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" + assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]" - return UOp(UOps.BIND, self.dtype, (self, self.const_like(val))) + return UOp(Ops.BIND, self.dtype, (self, self.const_like(val))) def unbind(self) -> Tuple[Variable, int]: - assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}" + assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}" return self.src[0], self.src[1].arg @property def val(self) -> int: return self.unbind()[1] def vars(self) -> Set[UOp]: - bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR]) + bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR]) bound_var_base = set(x.src[0] for x in bound_vars) - all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) + all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR]) return bound_vars.union(set([x for x in all_vars if x not in bound_var_base])) def variables(self) -> List[Variable]: st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS] - return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg) + return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg) # *** uop symbolic stuff *** def const_factor(self) -> int: """largest known int that divides self""" - if self.op is UOps.CONST: return self.arg - if self.op is UOps.VCONST: return functools.reduce(math.gcd, self.arg) - if self.op is UOps.ALU: + if self.op is Ops.CONST: return self.arg + if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg) + if self.op is Ops.ALU: if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) - if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1 + if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 return 1 def divides(self, v) -> Optional[UOp]: if v==1: return self - if self.op is UOps.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None - if self.op is UOps.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None - if self.op is UOps.ALU: + if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None + if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None + if self.op is Ops.ALU: if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None if self.arg is BinaryOps.MUL: if 
(d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] @@ -398,20 +398,20 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def _min_max(self) -> Tuple[ConstType, ConstType]: # NOTE: returned UOp is assumed to be CONST - if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] - if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax - if self.op is UOps.BIND: return self.src[0].vmin, self.src[0].vmax # ignore the bound value - if self.op in {UOps.EXPAND, UOps.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) + if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] + if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax + if self.op is Ops.BIND: return self.src[0].vmin, self.src[0].vmax # ignore the bound value + if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) # TODO: UOps.SPECIAL is UOps.DEFINE_VAR - if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype) - if self.op is UOps.CONST: return self.arg, self.arg - if self.op is UOps.VCONST: return (min(self.arg), max(self.arg)) - if self.op is UOps.ALU and not dtypes.is_float(self.dtype): + if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype) + if self.op is Ops.CONST: return self.arg, self.arg + if self.op is Ops.VCONST: return (min(self.arg), max(self.arg)) + if self.op is Ops.ALU and not dtypes.is_float(self.dtype): s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)] if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals) if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1 - if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST: + if self.arg is BinaryOps.IDIV and s1.op is Ops.CONST: if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg if s1.arg < 0 and s0.vmin >= 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg) if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax) @@ -430,7 +430,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def _sym_fxn(self): sself = self.simplify() - varnames = tuple(x.arg[0] for x in sself.sparents if x.op is UOps.DEFINE_VAR) + varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR) # TODO: sanitize varnames, or don't use naked eval while staying fast return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used @@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def render(self, simplify=True) -> str: ret = graph_rewrite(self.simplify() if simplify else self, renderer) - return ret.arg if ret.op is UOps.NOOP else str(ret) + return ret.arg if ret.op is Ops.NOOP else str(ret) @dataclass(frozen=True) class KernelInfo: @@ -476,7 +476,7 @@ def exec_alu(op:Op, dtype:DType, operands, truncate_output=True): def print_uops(uops:List[UOp]): for i,u in enumerate(uops): - formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src] + formatted_parents = [uops.index(x) if x.op is not Ops.CONST else f"{x.arg}" for x in u.src] print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):25s} " f"{str(formatted_parents):32s} {u.arg}") def flops_mem(uops:List[UOp], ignore_indexing=False) 
 def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
@@ -487,26 +487,26 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
   dont_count: Set[UOp] = set()
   if ignore_indexing:
     for u in uops:
-      if u.op in {UOps.LOAD, UOps.STORE}:
+      if u.op in {Ops.LOAD, Ops.STORE}:
         dont_count = dont_count.union(u.src[0].sparents)
         if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
-      elif u.op is UOps.IF:
+      elif u.op is Ops.IF:
         dont_count = dont_count.union(u.src[0].sparents)
   for u in uops:
-    if u.op is UOps.RANGE:
+    if u.op is Ops.RANGE:
       mult_stack.append(mults)
       mults *= (u.src[1] - u.src[0]).ssimplify()
-    elif u.op is UOps.ENDRANGE:
+    elif u.op is Ops.ENDRANGE:
       mults = mult_stack.pop(-1)
-    elif u.op is UOps.SPECIAL:
+    elif u.op is Ops.SPECIAL:
       mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
-    elif u.op is UOps.LOAD:
+    elif u.op is Ops.LOAD:
       mem += u.dtype.itemsize * mults
-    elif u.op is UOps.STORE:
+    elif u.op is Ops.STORE:
       mem += u.src[1].dtype.itemsize * mults
-    elif u.op is UOps.ALU and u not in dont_count:
+    elif u.op is Ops.ALU and u not in dont_count:
       flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
-    elif u.op is UOps.WMMA and u not in dont_count:
+    elif u.op is Ops.WMMA and u not in dont_count:
       flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
   return flops, mem
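The `flops_mem` hunk keeps a stack of loop-trip multipliers: `RANGE` pushes, `ENDRANGE` pops, and each `ALU`/`LOAD`/`STORE` contribution is weighted by the current product. A runnable toy version of that accounting (made-up flat op list, not tinygrad UOps):

```python
# Two nested loops of 16 and 4 trips: the inner LOAD/ALU are counted 64 times.
uops = [("RANGE", 16), ("RANGE", 4), ("LOAD", 4), ("ALU", None),
        ("ENDRANGE", None), ("ENDRANGE", None)]
flops, mem, mults, stack = 0, 0, 1, []
for op, arg in uops:
    if op == "RANGE": stack.append(mults); mults *= arg
    elif op == "ENDRANGE": mults = stack.pop()
    elif op == "LOAD": mem += arg * mults      # arg stands in for itemsize
    elif op == "ALU": flops += mults
assert (flops, mem) == (64, 256)
```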
@@ -525,11 +525,11 @@ def lines(fn) -> List[str]:

 class UPat(MathTrait):
   __slots__ = ["op", "dtype", "arg", "name", "src"]
-  def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
+  def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
                src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None, name:Optional[str]=None, allow_any_len:bool=False, location=None,
-               custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
-    self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
+               custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None):
+    self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else op
     self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
     self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
     self.src: Any = None
@@ -561,24 +561,24 @@ class UPat(MathTrait):
   @staticmethod
   @functools.lru_cache(None)
   def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
-    return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
+    return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
   @staticmethod
-  def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
+  def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)

   # copied from UOp
-  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
-  def view(self, st=None, **kwargs): return UPat(UOps.VIEW, self.dtype, (self,), st, **kwargs)
-  def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
-  def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
-  def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
-  def load(self, *src:UPat, **kwargs): return UPat(UOps.LOAD, src=(self,)+src, **kwargs)
-  def store(self, *src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
-  def assign(self, x:UPat): return UPat(UOps.ASSIGN, self.dtype, (self,x))
+  def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
+  def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
+  def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
+  def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
+  def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
+  def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
+  def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
+  def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))

   def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
   def alu(self, arg, *src:UPat):
     asrc = (self,)+src
-    return UPat(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
+    return UPat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)

   def printable(self:UPat) -> str:
     try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -627,7 +627,7 @@ class PatternMatcher:
   def __init__(self, patterns:List[Tuple[UPat, Callable]]):
     self.patterns = patterns
     # NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
-    self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
+    self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
     # uop is required, arg is optional
     for p,fxn in self.patterns:
       assert p.op is not None
@@ -745,84 +745,84 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:

 # this is the matcher for the final rendered UOps
 # matcher functions returns True or False (or None to not match)
 spec = PatternMatcher([
-  (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
-  (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
-  (UPat(UOps.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
-   lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
-  (UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
+  (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
+  (UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
+  (UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
+   lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
+  (UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),

-  (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
-  (UPat(UOps.SPECIAL, src=()), lambda: True),
+  (UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
+  (UPat(Ops.SPECIAL, src=()), lambda: True),

   # TODO: confirm the args of both of these are shapetrackers
-  (UPat(UOps.VIEW, src=()), lambda: True),
-  (UPat(UOps.VIEW, src=(UPat(),)), lambda: True),
+  (UPat(Ops.VIEW, src=()), lambda: True),
+  (UPat(Ops.VIEW, src=(UPat(),)), lambda: True),
-  (UPat(UOps.VALID, dtypes.bool, (UPat(UOps.VIEW),)), lambda: True),
-  (UPat(UOps.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
+  (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
+  (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),

   # early LOAD has a <buf, view>
-  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW))), lambda: True),
-  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat(UOps.STORE))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),

   # early STORE has a <buf, view, val>
-  (UPat(UOps.STORE, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat())), lambda: True),
+  (UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),

   # **** new style load/store ****

   # INDEX is used in new style load/store
-  (UPat(UOps.INDEX, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
+  (UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),

   # LOAD takes a <bufidx, alt?, gate?, barrier?>
-  (UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)),)), lambda: True),
-  (UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
-  (UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
+  (UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),

   # STORE takes a <bufidx, val, gate?>
-  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat())), lambda: True),
-  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
-  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(UOps.IF))), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),

   # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
-  (UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
+  (UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
    lambda w,x,y: w.dtype == x.dtype == y.dtype),
-  (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
-  (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
+  (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
+  (UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype), # and SHL/SHR, the shift distance is an int - (UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL), + (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL), lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), - (UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR), + (UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR), lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)), - (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), - (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), + (UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), + (UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), - (UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True), - (UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True), + (UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True), + (UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True), # all WMMA has 3 args, - (UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True), - (UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), - (UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), + (UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True), + (UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), + (UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # if has a - (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), - (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True), - (UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True), + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True), + (UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True), + (UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True), - (UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()), - (UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), - (UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), - (UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None), - (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local + (UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()), + (UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()), + (UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)), + (UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None), + (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local # NOTE: for 
   #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
-  (UPat(UOps.SINK, dtypes.void), lambda: True),
-  (UPat(UOps.NOOP), lambda: True),
+  (UPat(Ops.SINK, dtypes.void), lambda: True),
+  (UPat(Ops.NOOP), lambda: True),

   # PTX LOAD/STORE
-  (UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
-  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
+  (UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
+  (UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
 ])

 def type_verify(uops:List[UOp]):
@@ -842,7 +842,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:

 # *** most of symbolic lives here now ***

 def split_uop(x:UOp, sep:BinaryOps):
-  if x.op is UOps.ALU and x.arg is sep:
+  if x.op is Ops.ALU and x.arg is sep:
     for s in x.src: yield from split_uop(s, sep)
   else: yield x
@@ -859,7 +859,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
       assert divides is not None
       remainder.append(divides)
       something_changed = True
-    elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0:
+    elif u.op is Ops.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
       remainder.append(u.src[0])
       something_changed = True
     else: remainder.append(u)
@@ -874,7 +874,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
   quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
   for u in split_uop(x, BinaryOps.ADD):
-    if u.op is UOps.CONST:
+    if u.op is Ops.CONST:
       # add all const together first
       if rem_const != 0: something_changed = True
       rem_const += u.arg
@@ -886,7 +886,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
       something_changed = True
     else:
       # divisor is the smallest common divisor of all MULs
-      if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
+      if u.op is Ops.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
       remainder.append(u)
       gcd = math.gcd(gcd, factor)
@@ -917,11 +917,11 @@ def fold_unrolled_divs(divs:UOp):
   # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
   add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
   for u in add_chain:
-    if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST): return None
+    if not (u.op is Ops.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is Ops.CONST): return None
     if denominator is None: denominator = u.src[1].arg
     if denominator != u.src[1].arg: return None
     # assumed CONST is the last of an ADD
-    if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST:
+    if (s0:=u.src[0]).op is Ops.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
       seen_const.append(s0.src[1].arg)
       s0 = s0.src[0]
     else: seen_const.append(0)
@@ -933,7 +933,7 @@ def fold_unrolled_divs(divs:UOp):
     if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
   return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None

-def is_irreducible(u:UOp): return u.op in (UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
+def is_irreducible(u:UOp): return u.op in (Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE)
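`split_uop` flattens an ADD chain so `mod_folding`/`div_folding` can reduce each term independently. A tiny standalone illustration of the mod-folding idea, with `(coefficient, symbol)` pairs standing in for UOps:

```python
# Any term whose constant factor is divisible by c contributes nothing mod c,
# which is why (x*8 + y) % 4 can drop the x*8 term when 8 % 4 == 0.
def fold_mod(terms, c):
    return [(f, v) for f, v in terms if f % c != 0]

terms = [(8, "x"), (1, "y")]              # represents x*8 + y
assert fold_mod(terms, 4) == [(1, "y")]   # (x*8 + y) % 4 == y % 4
```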
 def canonicalize_simplex(X:UOp) -> Optional[UOp]:
   # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
@@ -941,7 +941,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
   changed, ret = False, []
   for u in split_uop(X, BinaryOps.ADD):
     # assumed the const is the last src of MUL
-    if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
+    if u.op is Ops.ALU and u.arg is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
       changed = True
       u = u.src[0]
     if not (is_irreducible(u) and u.vmin >= 0): return None
@@ -951,8 +951,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
 def is_increasing(f:UOp) -> bool:
   # is f a monotonically increasing function regards its input
   if is_irreducible(f): return True
-  if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
-  if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
+  if f.op is Ops.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
+  if f.op is Ops.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
   return False  # False if not sure

 def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
@@ -960,10 +960,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
   # if it's X >= c, returns X, False, c

   # (X < c).ne(True) -> X >= c
-  if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is UOps.CONST and valid.src[1].arg == 1 and \
-    (s0:=valid.src[0]).op is UOps.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is UOps.CONST: return s0.src[0], False, s0.src[1].arg
+  if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
+    (s0:=valid.src[0]).op is Ops.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
   # X < c -> X <= c-1
-  if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
+  if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
   raise ValueError(f"not able to parse {valid=}")
@@ -983,7 +983,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
   # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop
   candidates = []
-  if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
+  if expr.op is Ops.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
     # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
     candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
   # try checking the whole clause
@@ -992,7 +992,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
   for candidate in candidates:
     # if every branch in candidate gives the same simplified uop, we can rewrite the uop
     newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
-    if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
+    if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
       if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
       if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
     elif all_same(newuops): uop = newuops[0]
@@ -1037,17 +1037,17 @@ symbolic_simple = PatternMatcher([
   # NOTE: this can be wrong for loaded NaN
   (UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
   # ** constant folding **
-  (UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
+  (UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
    lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
   # ** COMMUTATIVE flipping **
-  *[(UPat(UOps.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
+  *[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
   # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
   (UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
   (UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
   (UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
   # *** cast ***
-  (UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
-  (UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
+  (UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
+  (UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
 ])

 symbolic = symbolic_simple+PatternMatcher([
@@ -1063,7 +1063,7 @@ symbolic = symbolic_simple+PatternMatcher([
   (UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
   (UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
   # ALU min==max -> CONST (slow!)
-  (UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
+  (UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
   # max folding
   (UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
   # TODO: why does this rule break beautiful_mnist?
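`symbolic_simple` is applied by `graph_rewrite` until no rule fires. A minimal fixpoint-rewriting sketch over toy expression tuples (hypothetical rules mirroring the `x+0`, `x*1`, and integer-only `x*0` folds above, not the real PatternMatcher):

```python
# Rewrite bottom-up; each rule either fires or passes the node through unchanged.
def rewrite(e):
    if isinstance(e, tuple):
        e = tuple(rewrite(x) for x in e)
        if e[0] == "+" and e[2] == 0: return e[1]   # x + 0 -> x
        if e[0] == "*" and e[2] == 1: return e[1]   # x * 1 -> x
        if e[0] == "*" and e[2] == 0: return 0      # x * 0 -> 0 (ints only; NaN caveat above)
    return e

assert rewrite(("+", ("*", "x", 1), 0)) == "x"
```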
@@ -1090,11 +1090,11 @@ symbolic = symbolic_simple+PatternMatcher([
   (((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
    lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
   # ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
-  (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
-  (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
+  (UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
+  (UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
   # *** rules from symbolic ***
   # unrolled arange div folding
-  (UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
+  (UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
   # generic lt folding
   (UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
   # canonicalize a simplex with positive coefficients > 0
@@ -1115,23 +1115,23 @@ symbolic_flat = symbolic+PatternMatcher([
   ((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
 ])

-_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
+_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])

 # for debug
 syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
          BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
 renderer = PatternMatcher([
-  (UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
-  (UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
-  (UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
-  (UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
-  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(UOps.NOOP, arg=f"(-{x.src[0].arg})")),
-  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(UOps.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
-  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.MULACC),
-   lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
-  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
-   lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
-  (UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
+  (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
+  (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
+  (UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
+  (UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
+  (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
+  (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
{x.src[1].arg})")), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC), + lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE), + lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")), + (UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")), ]) # *** what was symbolic.py *** diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index a109167c8b..27676a5250 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any import functools from dataclasses import dataclass, field from tinygrad.helpers import to_function_name, dedup, prod -from tinygrad.ops import Op, UOps, UOp, flops_mem, sym_infer, sint, Variable +from tinygrad.ops import Op, Ops, UOp, flops_mem, sym_infer, sint, Variable from tinygrad.dtype import DType @dataclass(frozen=True) @@ -42,10 +42,10 @@ class Program: if not self._ran_post_init and self.uops is not None: # single pass through the uops for u in self.uops: - if u.op is UOps.DEFINE_VAR: self.vars.append(u) - if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg) - if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL]) - if u.op is UOps.SPECIAL: + if u.op is Ops.DEFINE_VAR: self.vars.append(u) + if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg) + if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL]) + if u.op is Ops.SPECIAL: # NOTE: you have to set local_size and global_size to the base [1,1,1] outside this if u.arg[0][0] == 'i': self.local_size = None special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 003221695d..95ea4dddc9 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -2,66 +2,66 @@ from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast import os, math from collections import defaultdict, Counter -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat, cast_float_to_bf16 +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore base_rewrite = PatternMatcher([ - (UPat(UOps.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]), - (UPat(UOps.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"), - (UPat(UOps.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), - (UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda ctx: "}"), - (UPat(UOps.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), + (UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]), + (UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"), + (UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"), + (UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"), + (UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"), # r method accesses - (UPat(UOps.RANGE, 
name="x"), + (UPat(Ops.RANGE, name="x"), lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"), - (UPat(UOps.VECTORIZE, name="x"), + (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \ (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")), - (UPat(UOps.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"), - (UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), - (UPat(UOps.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"), - (UPat(UOps.BARRIER), lambda ctx: ctx.barrier), - (UPat(UOps.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]), - (UPat(UOps.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), + (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), + (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"), + (UPat(Ops.BARRIER), lambda ctx: ctx.barrier), + (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]), + (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), # const - (UPat(UOps.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"), - (UPat(UOps.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"), - (UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None), - (UPat(UOps.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"), - (UPat(UOps.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"), - (UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"), - (UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), - (UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), + (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"), + (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"), + (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None), + (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"), + (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"), + (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"), + (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), + (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), # consts are rendered to larger type and casted - (UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"), - (UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"), - (UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: 
f"(({ctx.render_dtype(x.dtype)}){x.arg})"), + (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"), + (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"), + (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"), # default const render - (UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)), + (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), # new load/store - (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), + (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"), - (UPat(UOps.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), - (UPat(UOps.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), - (UPat(UOps.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), + (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), + (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), + (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), # alu/gep - (UPat(UOps.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg]( + (UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg]( *([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)), - (UPat(UOps.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ + (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), ]) extra_pm = PatternMatcher([ # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends? 
 extra_pm = PatternMatcher([
   # insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
-  (UPat(UOps.BITCAST, name="x"),
-   lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
+  (UPat(Ops.BITCAST, name="x"),
+   lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
   # gate any stores that aren't gated with ifs
-  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
-   lambda store: UOp(UOps.STORE, src=store.src[:2]+(UOp(UOps.IF, src=(store.src[2],)),))),
+  (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
+   lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
   # rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
-  (UPat(UOps.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
+  (UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
 ])

 def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -125,41 +125,41 @@ class CStyleLanguage(Renderer):
     depth = 1
     c: DefaultDict[str, int] = defaultdict(int)
     for u in uops:
-      if u.op in (UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR):
-        r[u] = f"data{u.arg}" if u.op is UOps.DEFINE_GLOBAL else u.arg[0]
+      if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
+        r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
         bufs[u] = (r[u], (u.dtype, False))
         continue

       # mark buffers that we store to writable
-      if u.op is UOps.STORE:
+      if u.op is Ops.STORE:
         for up in u.src[0].sparents:
-          if up.op is UOps.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
+          if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))

       # naming
       prefix = None
-      if u.op is UOps.SPECIAL:
+      if u.op is Ops.SPECIAL:
         r[u] = u.arg[0]
       else:
-        prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const",
-                  UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast", UOps.NOOP: "precast",
-                  UOps.INDEX: "bidx", UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
+        prefix = {Ops.RANGE: "ridx", Ops.ALU: "alu", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
+                  Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
+                  Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "unk")
         r[u] = f"{prefix}{c[prefix]}"

       l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
       assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

-      if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
-      if u.op in {UOps.CONST, UOps.GEP, UOps.INDEX} or (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST}
+      if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
+      if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, Ops.ALU, Ops.CAST, Ops.BITCAST}
                                                      and child_count[u] == 1 and not getenv("EXPAND_SSA")):
         r[u] = l
       else:
-        if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void:
-          if u.op is UOps.ASSIGN: r[u] = r[u.src[0]]
+        if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
+          if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
         else:
-          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not UOps.SPECIAL else "")
+          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
         kernel.append("  "*depth + l)
         if prefix: c[prefix] += 1  # if it was used, increment
-      if u.op in {UOps.IF, UOps.RANGE}: depth += 1
+      if u.op in {Ops.IF, Ops.RANGE}: depth += 1
     del self.r

     # NOTE: this relies on bufs dict preserving order
@@ -189,7 +189,7 @@ class ClangRenderer(CStyleLanguage):
   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
     # https://github.com/corsix/amx
-    for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
+    for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       prefix += [
         '#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
         '#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
@@ -214,13 +214,13 @@ class OpenCLRenderer(CStyleLanguage):
   type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
   string_rewrite = PatternMatcher([
-    (UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
     # load/store image (OpenCL)
-    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
+    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
      lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
-    (UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
+    (UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
      lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
-    (UPat(UOps.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
+    (UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
      lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
   ]) + base_rewrite
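The `]) + base_rewrite` composition matters: `PatternMatcher.__add__` concatenates pattern lists, so the OpenCL-specific rules are tried before the shared base rules. A toy model of that first-match-wins ordering (not the real matcher):

```python
# Two ordered rule lists; lookup walks them in order and the first hit wins.
opencl_rules = [("BITCAST", lambda x: f"as_float({x})")]
base_rules   = [("BITCAST", lambda x: f"(*((float*)&{x}))"), ("NOOP", lambda x: x)]

def render(op, x, rules):
    return next(fn(x) for o, fn in rules if o == op)

assert render("BITCAST", "a", opencl_rules + base_rules) == "as_float(a)"
```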
@@ -234,13 +234,13 @@ class IntelRenderer(OpenCLRenderer):
                  st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]

   string_rewrite = PatternMatcher([
-    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
-    (UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
   ]) + OpenCLRenderer.string_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
     prefix = []
-    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
+    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
       prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
     return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
@@ -272,17 +272,17 @@ class MetalRenderer(CStyleLanguage):
   # upcast to float32 all the ops that don't support bfloat16
   extra_matcher = PatternMatcher([
     # NOTE: this is copied from PTX
-    *[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
+    *[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
        lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
       for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
   ]) + extra_pm

   string_rewrite = PatternMatcher([
-    (UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
+    (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
   ]) + base_rewrite

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
-    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
+    prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
     for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
   simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
   b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
@@ -335,7 +335,7 @@ class CUDARenderer(CStyleLanguage):
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]

     dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
-    for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
+    for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
       upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
       wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
       n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]  # 4 => CUDA reg size in bytes
@@ -353,7 +353,7 @@ class CUDARenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)

   def get_kernel_modifier(self, uops:List[UOp]) -> str:
-    maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
+    maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
     return f"__launch_bounds__({maxThreadsPerBlock}) "
@@ -387,20 +387,20 @@ class AMDRenderer(CStyleLanguage):
   type_map = {dtypes.bfloat16: "hip_bfloat16"}
   extra_matcher = PatternMatcher([
     # cast bfloat16 alus to float
-    (UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
-     lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
-    (UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
+    (UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+     lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
+    (UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
      lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
-    (UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
+    (UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
      lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
     # add float intermediate casting for bfloat16
-    (UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
-    (UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
+    (UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
+    (UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
     # bfloat16 casting
     (UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
-    (UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
+    (UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
      lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
-    (UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
+    (UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm

   def render_vector_prefix(self, dtype:DType) -> str:
     vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@@ -414,7 +414,7 @@ class AMDRenderer(CStyleLanguage):
     if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
     prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]

-    for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):  # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
+    for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):  # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
       if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
       else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
   half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
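The AMD matcher handles bfloat16 by widening to float32, computing, and narrowing back. The narrowing used by helpers like `cast_float_to_bf16` follows the usual round-to-nearest-even-on-the-top-16-bits trick; a pure-Python sketch (the exact rounding in tinygrad may differ slightly):

```python
import struct

def f32_to_bf16_bits(x: float) -> int:
    # reinterpret the float as its 32-bit pattern, round, keep the high 16 bits
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7fff + ((bits >> 16) & 1)   # round to nearest even
    return bits >> 16

assert f32_to_bf16_bits(1.0) == 0x3f80
```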
@@ -423,7 +423,7 @@ class AMDRenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

   def get_kernel_modifier(self, uops:List[UOp]) -> str:
-    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
+    requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
     # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
     # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
     return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 0ea03f008b..987f3a7259 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -1,7 +1,7 @@
 from typing import Dict, Callable, List, Optional
 from llvmlite import ir
 from tinygrad.dtype import DType, PtrDType, dtypes
-from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, UOps, UOp
+from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, Ops, UOp
 from tinygrad.renderer import Renderer

 MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc')  # All from fast math, but nnan and ninf
@@ -71,7 +71,7 @@ class LLVMRenderer(Renderer):
     module = ir.Module(name=__file__)

     # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
-    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
+    buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
     buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}

     # create llvm function
@@ -90,14 +90,14 @@ class LLVMRenderer(Renderer):
     for u in uops:
       uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
-      if uop is UOps.INDEX:
+      if uop is Ops.INDEX:
         lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
-      elif uop is UOps.STORE:
+      elif uop is Ops.STORE:
         if len(src) > 2:
           with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
         else: bb[-1].store(lvars[src[1]], lvars[src[0]])
-      elif uop is UOps.ENDRANGE:
+      elif uop is Ops.ENDRANGE:
         loop_entry_bb, phis = loop_blocks.pop()
         idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
         lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
@@ -105,7 +105,7 @@ class LLVMRenderer(Renderer):
         bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
         bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
       else:
-        if uop is UOps.RANGE:
+        if uop is Ops.RANGE:
           bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
           bb[-2].branch(bb[-1].block)
@@ -119,10 +119,10 @@ class LLVMRenderer(Renderer):
           lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
           lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
           loop_blocks.append((bb[-1].block, phis))
-        elif uop is UOps.DEFINE_ACC:
+        elif uop is Ops.DEFINE_ACC:
           lvars[u] = const(src[0].arg, dtype)
           reduce_phis.append(u)
-        elif uop is UOps.LOAD:
+        elif uop is Ops.LOAD:
           if len(src) > 1:
             with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
               with then:
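The LLVM renderer drives `llvmlite.ir.IRBuilder` directly, as in the hunks above. A minimal, self-contained llvmlite sketch in the same spirit, one buffer argument, one store, `ret void` (an illustration, not the renderer's actual output):

```python
from llvmlite import ir

module = ir.Module(name="example")
fnty = ir.FunctionType(ir.VoidType(), [ir.IntType(32).as_pointer()])
func = ir.Function(module, fnty, name="kernel")
bb = ir.IRBuilder(func.append_basic_block("entry"))
bb.store(ir.Constant(ir.IntType(32), 42), func.args[0])  # analogous to Ops.STORE
bb.ret_void()
print(module)  # dumps the generated LLVM IR
```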
@@ -135,17 +135,17 @@ class LLVMRenderer(Renderer):
           else:
             val = bb[-1].load(lvars[src[0]])
           lvars[u] = val
-        elif uop is UOps.ASSIGN:
+        elif uop is Ops.ASSIGN:
           lvars[u] = lvars[src[1]]
           # ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC
           backward = src[0]
-          while backward.op is UOps.ASSIGN: backward = backward.src[0]
+          while backward.op is Ops.ASSIGN: backward = backward.src[0]
           lvars[backward] = lvars[u]
-        elif uop is UOps.ALU:
+        elif uop is Ops.ALU:
           lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
-        elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
-        elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
-        elif uop is UOps.CONST: lvars[u] = const(args, dtype)
+        elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
+        elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
+        elif uop is Ops.CONST: lvars[u] = const(args, dtype)
         else: raise RuntimeError(f"failed to render {uop}")

     bb[-1].ret_void()
diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py
index ccdf50eddb..78470e93dd 100644
--- a/tinygrad/renderer/ptx.py
+++ b/tinygrad/renderer/ptx.py
@@ -1,7 +1,7 @@
 from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
 import struct
 from collections import defaultdict
-from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
+from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
 from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
 from tinygrad.renderer import Renderer
 from tinygrad.renderer.cstyle import CUDARenderer
@@ -38,20 +38,20 @@ ptx_matcher = PatternMatcher([
   (UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
   (UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
   # upcast to float32 all the ops that don't support half
-  *[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
+  *[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
      lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
     for op in asm_for_op.keys() if op not in supports_half],
   # load/store bool -> uint8
-  (UPat(UOps.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
+  (UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
    lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
-  (UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
+  (UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
    lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
   # load/store use pointer arithmetic, and the cast does nothing
-  (UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
-  (UPat(UOps.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
+  (UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
   # ptx shr and shl instructions require y to be uint
-  (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(UOps.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
-  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(UOps.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
+  (UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
+  (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
 ])

 class PTXRenderer(Renderer):
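The `ptx_matcher` INDEX rule above lowers addressing to explicit 64-bit pointer arithmetic, base plus `idx * itemsize`. A trivial standalone sketch with plain ints standing in for UOps:

```python
# Element 3 of an int32 buffer at 0x1000 lives at 0x1000 + 3*4.
def index_to_addr(base: int, idx: int, itemsize: int) -> int:
    return base + idx * itemsize

assert index_to_addr(0x1000, 3, 4) == 0x100c
```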
u.op,u.dtype,u.src,u.arg - if uop is UOps.IF: + if uop is Ops.IF: pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True) kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True)) - elif uop is UOps.BARRIER and self.barrier: kk(self.barrier) - elif uop is UOps.ENDRANGE: + elif uop is Ops.BARRIER and self.barrier: kk(self.barrier) + elif uop is Ops.ENDRANGE: kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]), self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int])) kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred)) - elif uop is UOps.ENDIF: + elif uop is Ops.ENDIF: kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:") - elif uop is UOps.STORE: + elif uop is Ops.STORE: assert src[0].dtype == dtypes.int64, "store isn't int64" - mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global' - gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not UOps.IF else "" + mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global' + gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else "" if src[1].dtype.count > 1: kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};") else: kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};") else: - if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) - elif uop is UOps.ALU: + if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) + elif uop is Ops.ALU: src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype])) - elif uop is UOps.DEFINE_ACC: + elif uop is Ops.DEFINE_ACC: if dtype.count > 1: r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};") else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};") - elif uop is UOps.SPECIAL: + elif uop is Ops.SPECIAL: assert args[0][0] != "i", "idx not supported" kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};") r[u] = "%" + args[0] kernel = [f".reg .u32 %{args[0]};"] + kernel - elif uop is UOps.DEFINE_VAR: + elif uop is Ops.DEFINE_VAR: bufs.append((args[0], dtype)) r[u] = f"%{args[0]}" kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param")) - elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True) - elif uop is UOps.GEP: + elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True) + elif uop is Ops.GEP: assert len(u.arg) == 1 r[u] = r[src[0]][u.arg[0]] - elif uop is UOps.LOAD: + elif uop is Ops.LOAD: assert src[0].dtype == dtypes.int64, "load isn't int64" - mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global' - has_gate = len(src) > 2 and src[2].op is UOps.ALU + mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global' + has_gate = len(src) > 2 and src[2].op is Ops.ALU if 
dtype.count > 1: r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] if has_gate: @@ -200,25 +200,25 @@ class PTXRenderer(Renderer): else: kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None, alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0)) - elif uop is UOps.ASSIGN: + elif uop is Ops.ASSIGN: if dtype.count > 1: for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};") else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};") r[u] = r[src[0]] # NOTE: casting to str is fine because you can't vectorize a vectorize - elif uop is UOps.VECTORIZE: r[u] = [cast(str,r[x]) for x in src] - elif uop in {UOps.CAST, UOps.BITCAST}: - _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u) - elif uop is UOps.DEFINE_LOCAL: + elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src] + elif uop in {Ops.CAST, Ops.BITCAST}: + _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u) + elif uop is Ops.DEFINE_LOCAL: # TODO: we should sum these, and fetch 0xC000 from somewhere assert args[1]*dtype.itemsize <= 0xC000, "too large local" kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype)) - elif uop is UOps.DEFINE_GLOBAL: + elif uop is Ops.DEFINE_GLOBAL: bufs.append((nm:=f"data{args}", dtype)) r[u] = f"%{nm}" dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param")) - elif uop is UOps.WMMA: + elif uop is Ops.WMMA: _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2]) dt_map = { dtypes.half: "f16" } diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 1b87f92843..8b363a0912 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate from tinygrad.helpers import all_same, getenv, flatten from tinygrad.device import Compiled, Compiler, Allocator -from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, UOps, UOp +from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer @@ -26,7 +26,7 @@ def _store(m, i, v): class PythonProgram: def __init__(self, name:str, lib:bytes): - self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib) + self.uops: List[Tuple[Ops, Optional[DType], List[int], Any]] = pickle.loads(lib) def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): st = time.perf_counter() warp = list(itertools.product(*[range(x) for x in local_size[::-1]])) @@ -40,12 +40,12 @@ class PythonProgram: loop_ends: Dict[int, int] = {} while i < len(self.uops): uop, dtype, idp, arg = self.uops[i] - void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF} - if uop is UOps.DEFINE_ACC: idp = [idp[0]] + void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF} + if uop is Ops.DEFINE_ACC: idp = [idp[0]] inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if 
getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) - if uop is UOps.STORE: + if uop is Ops.STORE: if len(inp) == 2: inp.append([True] * len(inp[0])) # set the gate to True if dtp[1].count > 1: for j,val in enumerate(inp[1]): @@ -56,32 +56,32 @@ class PythonProgram: if g: _store(m, o, v) i += 1 continue - if uop is UOps.ENDRANGE: + if uop is Ops.ENDRANGE: loop_ends[idp[0]] = i i = idp[0] continue - if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF): + if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF): # in the python emulator, the warp is always in sync i += 1 continue assert dtype is not None, f"{uop} is missing a dtype" dl[i] = dtype - if uop is UOps.DEFINE_GLOBAL: + if uop is Ops.DEFINE_GLOBAL: assert dtype.fmt is not None ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size - elif uop is UOps.DEFINE_LOCAL: + elif uop is Ops.DEFINE_LOCAL: assert dtype.fmt is not None lbuf = memoryview(bytearray(arg[1]*dtype.itemsize)) ul[i] = [lbuf.cast(dtype.fmt)] * warp_size - elif uop is UOps.DEFINE_VAR: + elif uop is Ops.DEFINE_VAR: ul[i] = [pvals.pop(0)] * warp_size - elif uop is UOps.SPECIAL: + elif uop is Ops.SPECIAL: if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp] - elif uop is UOps.CONST: ul[i] = [arg] * warp_size - elif uop is UOps.DEFINE_ACC: + elif uop is Ops.CONST: ul[i] = [arg] * warp_size + elif uop is Ops.DEFINE_ACC: ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size - elif uop is UOps.INDEX: + elif uop is Ops.INDEX: ret = [] if isinstance(dtp[0], ImageDType): for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]): @@ -90,9 +90,9 @@ class PythonProgram: else: for m,o in zip(inp[0], inp[1]): ret.append((m,o)) ul[i] = ret - elif uop is UOps.CAST and isinstance(dtype, PtrDType): + elif uop is Ops.CAST and isinstance(dtype, PtrDType): ul[i] = inp[0] - elif uop is UOps.RANGE: + elif uop is Ops.RANGE: if i not in ul: ul[i] = [inp[0][0]] * warp_size else: for j in range(len(ul[i])): @@ -101,11 +101,11 @@ class PythonProgram: del ul[i] i = loop_ends[i] + 1 continue - elif uop is UOps.VECTORIZE: ul[i] = inp - elif uop in {UOps.CAST, UOps.BITCAST}: + elif uop is Ops.VECTORIZE: ul[i] = inp + elif uop in {Ops.CAST, Ops.BITCAST}: assert dtp[0].fmt and dtype.fmt pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt - if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0]))) + if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0]))) else: casted = [dtypes.as_const(x, dtype) for x in inp[0]] if dtypes.is_int(dtype): @@ -114,18 +114,18 @@ class PythonProgram: elif dtypes.is_float(dtype): casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted] ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted))) - elif uop is UOps.LOAD: + elif uop is Ops.LOAD: if dtype.count > 1: ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)] else: ul[i] = load(inp) - elif uop is UOps.ASSIGN: + elif uop is Ops.ASSIGN: for j in range(len(inp[0])): inp[0][j] = inp[1][j] ul[i] = inp[0] - elif uop is UOps.GEP: + elif uop is Ops.GEP: assert len(arg) == 1 ul[i] = inp[0][arg[0]] - elif uop is UOps.WMMA: + elif uop is Ops.WMMA: # here are the models for the WMMA instruction on the different hardware def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map): assert 
len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}" @@ -180,7 +180,7 @@ class PythonProgram: def c_map(_, elem): return (elem%16, elem//16) ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map) else: raise NotImplementedError(f"unimplemented tensor core {arg}") - elif uop is UOps.ALU: + elif uop is Ops.ALU: assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}" assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}" ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)] diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 05420d50df..0a21f01097 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint +from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint @dataclass(frozen=True, order=True) class ShapeTracker: @@ -40,7 +40,7 @@ class ShapeTracker: def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape)) - def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self) + def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self) def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]: idx, valid = self.views[-1].to_indexed_uops(_idxs) @@ -75,20 +75,20 @@ class ShapeTracker: ret: List[Optional[sint]] = [None] * len(self.shape) idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops()) for c in split_uop(idx, BinaryOps.ADD): - if c.op is UOps.RANGE: ret[c.arg] = 1 - if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg - if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg - used_ranges = [x.arg for x in idx.sparents if x.op is UOps.RANGE] + if c.op is Ops.RANGE: ret[c.arg] = 1 + if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg + if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg + used_ranges = [x.arg for x in idx.sparents if x.op is Ops.RANGE] ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)] if not ignore_valid: - for masked_axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE]: ret[masked_axis] = None + for masked_axis in [x.arg for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None return tuple(ret) def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] def axis_is_masked(self, axis:int) -> bool: _, valid = self.to_indexed_uops() - return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is UOps.RANGE] + return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: diff --git 
a/tinygrad/tensor.py b/tinygrad/tensor.py index 1ee339ecc4..b37bc82188 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN from tinygrad.multi import MultiLazyBuffer -from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait +from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait from tinygrad.device import Device, Buffer, BufferOptions from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.realize import run_schedule @@ -136,7 +136,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, UOp): - assert data.op is UOps.BIND and data.src[0].op is UOps.DEFINE_VAR and data.src[1].op is UOps.CONST, f"can't create tensor from UOp {data}" + assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}" data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data) elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype) elif isinstance(data, (list, tuple)): @@ -382,9 +382,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method @staticmethod def from_uop(y:UOp, **kwargs) -> Tensor: - if y.op is UOps.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor - if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) - if y.op is UOps.ALU: + if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor + if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) + if y.op is Ops.ALU: if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index ba78f1f792..867ce3f902 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -5,13 +5,13 @@ from urllib.parse import parse_qs, urlparse from dataclasses import asdict, dataclass from typing import Any, Dict, List, Tuple, Optional from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap -from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines +from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines from tinygrad.codegen.kernel import Kernel -uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0", - UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484", - UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff", - UOps.INDEX: "#e8ffa0", UOps.WMMA: "#efefc0", 
UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"} +uops_colors = {Ops.ALU: "#ffffc0", Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", + Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484", + Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488"} # ** API spec @@ -50,7 +50,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List for k,ctxs in contexts: name = to_function_name(k.name) if isinstance(k, Kernel) else k for ctx in ctxs: - if ctx.sink.op is UOps.CONST: continue + if ctx.sink.op is Ops.CONST: continue upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None] if name not in kernels: kernels[name] = [] kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats))) @@ -60,11 +60,11 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: assert isinstance(x, UOp) graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} for u in x.sparents: - if u.op is UOps.CONST: continue + if u.op is Ops.CONST: continue label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}" for idx,x in enumerate(u.src): - if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}" - graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff")) + if x.op is Ops.CONST: label += f"\nCONST{idx} {x.arg:g}" + graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not Ops.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: if (found:=replaces.get(base)) is not None: return found @@ -88,7 +88,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata) if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") # update ret data - g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST]) + g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST]) g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) g.graphs.append(sink:=new_sink) return g