s/UOps/Ops (#7500)

* s/UOps/Ops [pr]

* fix
George Hotz
2024-11-03 11:26:10 +08:00
committed by GitHub
parent d078dcd0c8
commit c8bf09b7d4
58 changed files with 3003 additions and 3002 deletions
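In short: this commit renames the opcode enum tinygrad.ops.UOps to Ops. UOp itself and all call sites keep the same shape; only the enum name changes. A minimal before/after sketch (mine, based on the hunks below, not part of the diff):

# before the commit: from tinygrad.ops import UOp, UOps
# after the commit: the same enum is imported as Ops
from tinygrad.ops import UOp, Ops
from tinygrad.dtype import dtypes

buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)  # same constructor, renamed enum
assert buf.op is Ops.DEFINE_GLOBAL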

View File

@@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this!
import struct
from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device
from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
# allocate some buffers + load in values
@@ -49,14 +49,14 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
# describe the computation
buf_1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
ld_1 = UOp(Ops.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
ld_2 = UOp(Ops.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
alu = ld_1 + ld_2
output_buf = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
s = UOp(UOps.SINK, dtypes.void, (st_0,))
output_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(Ops.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
s = UOp(Ops.SINK, dtypes.void, (st_0,))
# convert the computation to a "linearized" format (print the format)
from tinygrad.engine.realize import get_kernel, CompiledRunner
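Not part of the diff: a hedged sketch of how this docs example typically continues, assuming the get_kernel/CompiledRunner API imported above (exact calls may differ slightly in this revision):

out = Buffer(DEVICE, 1, dtypes.int32).allocate()      # destination buffer
kernel = get_kernel(Device[DEVICE].renderer, s)       # lower the Ops.SINK ast to a Kernel
fxn = CompiledRunner(kernel.to_program())             # compile it for DEVICE
fxn.exec([out, a, b])                                 # run: out = a + b
print(out.as_buffer().cast('I')[0])                   # 5, if a holds 2 and b holds 3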

View File

@@ -4,7 +4,7 @@
members_order: source
show_labels: false
::: tinygrad.ops.UOps
::: tinygrad.ops.Ops
options:
members: true
members_order: source

View File

@@ -4,7 +4,7 @@ from extra.mcts_search import mcts_search
from examples.mlperf.helpers import get_mlperf_bert_model
from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOps, sym_infer
from tinygrad.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
@@ -66,7 +66,7 @@ if __name__ == "__main__":
print(f"optimizing for {Device.DEFAULT}")
sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
sched = [x for x in sched if x.ast.op is UOps.SINK]
sched = [x for x in sched if x.ast.op is Ops.SINK]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]

View File

@@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_kernel, run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.ops import MetaOps, UOps
from tinygrad.ops import MetaOps, Ops
TIMING = getenv("TIMING")
@@ -41,7 +41,7 @@ if __name__ == "__main__":
print(f"calls {i}:", len(sched))
#run_schedule(sched[:])
sched = memory_planner(sched)
ast_dedup = dedup([si.ast for si in sched if si.ast.op is UOps.SINK])
ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK])
srcs = {}
for ast in ast_dedup:
k = get_kernel(Device["CLANG"].renderer, ast)
@@ -82,7 +82,7 @@ if __name__ == "__main__":
for i,si in enumerate(sched):
bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
all_bufs += bufs
if si.ast.op is not UOps.SINK:
if si.ast.op is not Ops.SINK:
print(f"// {si.ast.op}", bufs)
else:
print(f"{srcs[si.ast][0]}({', '.join([x[0] for x in bufs])})")

View File

@@ -19,7 +19,7 @@ from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.tensor import _to_np_dtype
Device.DEFAULT = "GPU"
@@ -50,7 +50,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
# confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
assert all(si.ast.op is UOps.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
assert all(si.ast.op is Ops.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
return schedule, schedule_independent, inputs
def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
@@ -106,7 +106,7 @@ if __name__ == "__main__":
#exit(0)
schedule, schedule_independent, inputs = get_schedule(onnx_data)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is UOps.SINK)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is Ops.SINK)
print(f"{len(schedule_input)} inputs")
run_schedule(schedule_independent)

View File

@@ -1,5 +1,5 @@
from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast
from tinygrad.codegen.kernel import UOps, MemOp, UOp
from tinygrad.codegen.kernel import Ops, MemOp, UOp
from tinygrad.ops import BinaryOps, UnaryOps
from tinygrad.dtype import DType, dtypes
from tinygrad.helpers import DEBUG
@@ -23,7 +23,7 @@ class Register(NamedTuple):
return []
class AssemblyInstruction(NamedTuple):
op: UOps
op: Ops
out: Optional[Register]
vin: List[Union[Register, int, float]]
arg: Any = None
@@ -49,21 +49,21 @@ class AssemblyLanguage:
def render_numnode(self, b) -> Register:
key = ("num", b)
if key not in self.tor: self.ins.append(AssemblyInstruction(UOps.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
if key not in self.tor: self.ins.append(AssemblyInstruction(Ops.LOAD, self.newreg(key, scalar=True, dtype=dtypes.int32), [], b))
return self.tor[key]
def render_alu(self, op, a:Register, b:Union[Register, int, float], dtype=dtypes.int32) -> Register:
key = (op, a, b)
if key not in self.tor:
#if not isinstance(b, Register): b = render_numnode(b)
self.ins.append(AssemblyInstruction(UOps.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
self.ins.append(AssemblyInstruction(Ops.ALU, self.newreg(key, dtype=dtype, scalar=a.scalar and (not isinstance(b, Register) or b.scalar)), [a, b], op))
return self.tor[key]
def render_cast(self, a:Register, new_dtype:DType) -> Register:
if a.dtype == new_dtype: return a
key = (a, new_dtype)
if key not in self.tor:
self.ins.append(AssemblyInstruction(UOps.CAST, self.newreg(key, dtype=new_dtype), [a]))
self.ins.append(AssemblyInstruction(Ops.CAST, self.newreg(key, dtype=new_dtype), [a]))
return self.tor[key]
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.tor[self], NumNode: lambda self, ops, ctx: ctx.render_numnode(self.b),
@@ -87,7 +87,7 @@ class AssemblyLanguage:
if self.supports_load3:
if reg.scalar:
new_reg = self.newreg((reg.nm, 'vec'), dtype=reg.dtype)
self.ins.append(AssemblyInstruction(UOps.ALU, new_reg, [reg], UnaryOps.NOOP))
self.ins.append(AssemblyInstruction(Ops.ALU, new_reg, [reg], UnaryOps.NOOP))
reg = new_reg
return self.tor[args.name], reg, off
reg = self.render_alu(BinaryOps.ADD, self.render_cast(reg, dtypes.uint64), self.tor[args.name], dtype=dtypes.uint64)
@@ -98,91 +98,91 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]):
lang.ins.clear()
lang.tor.clear()
lang.cnts.clear()
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == UOps.DEFINE_GLOBAL}
buf_to_dtype = {args:dtype for uop,dtype,_,args,_ in uops if uop == Ops.DEFINE_GLOBAL}
global_size, local_size = [], []
skipload_branch = 0
lang.ins += [AssemblyInstruction(UOps.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
lang.ins += [AssemblyInstruction(Ops.SPECIAL, lang.newreg(buf, dtype=dtypes.uint64, scalar=True), [], buf) for buf in buf_to_dtype]
for u in uops:
uop,dtype,vin,args,_ = u
if uop == UOps.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(UOps.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == UOps.LOOP:
if uop == Ops.DEFINE_LOCAL:
lang.ins.append(AssemblyInstruction(Ops.DEFINE_LOCAL, None, [], args))
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.newreg(args[0], dtype=dtypes.uint64), [args[0]], UnaryOps.NOOP))
elif uop == Ops.LOOP:
if args[1] == "global":
for i,var in enumerate(args[0]):
global_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"gid{len(args[0])-1-i}"))
elif args[1] == "local":
for i,var in enumerate(args[0]):
local_size.append(var.max+1)
lang.ins.append(AssemblyInstruction(UOps.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
lang.ins.append(AssemblyInstruction(Ops.SPECIAL, lang.newreg(var, dtype=dtypes.int32), [], f"lid{len(args[0])-1-i}"))
else:
for var in args[0]:
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], "$loop_"+var.expr))
elif uop == UOps.ENDLOOP:
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(var, dtype=dtypes.int32, scalar=True), [], 0))
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], "$loop_"+var.expr))
elif uop == Ops.ENDLOOP:
if args[1] not in ["global", "local", "global+local"]:
for var in reversed(args[0]):
if not isinstance(var, NumNode): # TODO: why is this coming through?
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[var], [lang.tor[var], 1], BinaryOps.ADD))
pred = lang.render_alu(BinaryOps.CMPLT, lang.tor[var], var.max+1, dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], ("$loop_"+var.expr, True)))
elif args[1] == "global+local":
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"gid{i}")))
elif args[1] == 'local':
for i, var in enumerate(reversed(args[0])):
lang.ins.append(AssemblyInstruction(UOps.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
elif uop == UOps.CAST:
lang.ins.append(AssemblyInstruction(Ops.ENDLOOP, None, [lang.tor[var]], (var.max+1, f"lid{i}")))
elif uop == Ops.CAST:
# TODO: we should reconsider outputting CAST in the linearizer. these are needless copies
out = lang.newreg(u, dtype)
for i,sr in enumerate(out.subregs()):
lang.ins.append(AssemblyInstruction(UOps.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
elif uop == UOps.ALU:
lang.ins.append(AssemblyInstruction(Ops.ALU, sr, [lang.tor[vin[i]]], UnaryOps.NOOP))
elif uop == Ops.ALU:
out = lang.newreg(u, dtype) if u not in lang.tor else lang.tor[u]
# this is the only thing that can violate SSA
if args in [BinaryOps.CMPLT]:
pred_reg = lang.newreg((u, 'pred'), dtype=dtypes.bool)
lang.ins.append(AssemblyInstruction(UOps.ALU, pred_reg, [lang.tor[x] for x in vin], args))
lang.ins.append(AssemblyInstruction(UOps.CAST, out, [pred_reg], args))
lang.ins.append(AssemblyInstruction(Ops.ALU, pred_reg, [lang.tor[x] for x in vin], args))
lang.ins.append(AssemblyInstruction(Ops.CAST, out, [pred_reg], args))
elif args == BinaryOps.DIV and lang.no_div:
tmp = lang.newreg((u, "rcp"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[1]]], UnaryOps.RECIP))
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[vin[0]], tmp], BinaryOps.MUL))
elif args == UnaryOps.SIN and lang.sin_is_sin2pi:
tmp = lang.newreg((u, "2pi"))
lang.ins.append(AssemblyInstruction(UOps.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [tmp], args))
lang.ins.append(AssemblyInstruction(Ops.ALU, tmp, [lang.tor[vin[0]], 1/(math.pi*2)], BinaryOps.MUL))
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [tmp], args))
else:
lang.ins.append(AssemblyInstruction(UOps.ALU, out, [lang.tor[x] for x in vin], args))
elif uop == UOps.DEFINE_ACC:
lang.ins.append(AssemblyInstruction(Ops.ALU, out, [lang.tor[x] for x in vin], args))
elif uop == Ops.DEFINE_ACC:
reg = lang.newreg(u, dtype=dtype)
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args))
elif uop == UOps.SPECIAL:
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], args))
elif uop == Ops.SPECIAL:
lang.tor[u] = lang.tor[args]
elif uop == UOps.CONST:
lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args))
elif uop == UOps.LOAD:
elif uop == Ops.CONST:
lang.ins.append(AssemblyInstruction(Ops.LOAD, lang.newreg(u, dtype=dtype), [], args))
elif uop == Ops.LOAD:
idx, treg, off = lang.addr_w_offset(args)
reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar)))
if args.valid.min == 0:
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0))
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [], 0))
if args.valid.max == 1:
pred = args.valid.render(lang.render_ops, lang)
lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
lang.ins.append(AssemblyInstruction(Ops.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False)))
if args.valid.max == 1:
# NOTE: you can't compute the index in here, because it assumes it's all available later
lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
lang.ins.append(AssemblyInstruction(Ops.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if args.valid.min == 0 and args.valid.max == 1:
lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}"))
lang.ins.append(AssemblyInstruction(Ops.LABEL, None, [], f"$skipload_{skipload_branch}"))
skipload_branch += 1
elif uop == UOps.STORE:
elif uop == Ops.STORE:
if args is None:
lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
lang.ins.append(AssemblyInstruction(Ops.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP))
else:
idx, treg, off = lang.addr_w_offset(args)
lang.ins.append(AssemblyInstruction(UOps.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
lang.ins.append(AssemblyInstruction(Ops.STORE, None, [idx, lang.tor[vin[0]]] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None)))
if DEBUG >= 4:
for tins in lang.ins: print(tins)

View File

@@ -3,7 +3,7 @@ from platform import system
from typing import Tuple, Dict, List, Optional
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.codegen.kernel import UOps, UOp
from tinygrad.codegen.kernel import Ops, UOp
from tinygrad.helpers import CI
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
@@ -19,7 +19,7 @@ class ARM64Language(AssemblyLanguage): pass
def specialize_to_arm64(fn_nm, asm):
var_size = 16
prev_uop:Optional[UOps] = None
prev_uop:Optional[Ops] = None
ins = []
x_regs = ['x' + str(i) for i in reversed(range(12))]
s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
@@ -81,7 +81,7 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"mov x15, {mem_vars[v.nm]}")
ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")
if uop == UOps.SPECIAL:
if uop == Ops.SPECIAL:
if arg.startswith('data'):
# data 8 to n into the stack
if int(arg[4:]) >= 8:
@@ -90,7 +90,7 @@ def specialize_to_arm64(fn_nm, asm):
else:
ins.append(f"mov {rtor[out.nm]}, #0")
ins.append(f"loop_{arg}:")
elif uop == UOps.CAST:
elif uop == Ops.CAST:
if arg == BinaryOps.CMPLT:
if rtor[out.nm][0] == 's':
mov_imm(0.0, 's0')
@@ -102,7 +102,7 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
else:
ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
elif uop == UOps.ALU:
elif uop == Ops.ALU:
if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
@@ -136,7 +136,7 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
else:
ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
elif uop == UOps.LOAD:
elif uop == Ops.LOAD:
if arg.__class__ in (int, float):
mov_imm(arg, rtor[out.nm])
else:
@@ -146,20 +146,20 @@ def specialize_to_arm64(fn_nm, asm):
ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
elif uop == UOps.STORE:
elif uop == Ops.STORE:
#NOTE: if need casting load var in s/h0 or x/w12 temp regs
reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
ins.append(f"mov x15, #{arg[0]}")
ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
elif uop == UOps.COND_BRANCH:
elif uop == Ops.COND_BRANCH:
#TODO: this is a hack it shouldn't always be a cmp before a cond branch?
if prev_uop == UOps.LOAD:
if prev_uop == Ops.LOAD:
ins.append(f"cmp {rtor[vin[0].nm]}, #0")
ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
elif uop == UOps.LABEL:
elif uop == Ops.LABEL:
ins.append(f"{arg[1:]}:")
elif uop == UOps.ENDLOOP:
elif uop == Ops.ENDLOOP:
mov_imm(arg[0], "x15")
ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
ins.append(f"cmp {rtor[vin[0].nm]}, x15")

View File

@@ -1,7 +1,7 @@
from typing import List
import struct
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
from tinygrad.codegen.kernel import UOps, UOp
from tinygrad.codegen.kernel import Ops, UOp
from tinygrad import dtypes
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_cuda import arch
@@ -37,11 +37,11 @@ def specialize_to_ptx(lang, function_name):
UnaryOps.SIN: "sin.approx", UnaryOps.LOG2: "lg2.approx", UnaryOps.EXP2: "ex2.approx.ftz",
TernaryOps.MULACC: "fma.rn", TernaryOps.WHERE: "selp"}
for uop, out, vin, arg in lang.ins:
if uop == UOps.ENDLOOP:
if uop == Ops.ENDLOOP:
ins.append("bar.sync 0;")
elif uop == UOps.DEFINE_LOCAL:
elif uop == Ops.DEFINE_LOCAL:
ins.append(f".shared .align 4 .b8 {arg[0]}[{arg[1]*4}];")
elif uop == UOps.SPECIAL:
elif uop == Ops.SPECIAL:
if arg.startswith('data'):
param_cnt += 1
ins.append(f"ld.param.u64 {out}, [{arg}];")
@@ -51,7 +51,7 @@ def specialize_to_ptx(lang, function_name):
ins.append(f"mov.u32 {out}, %ctaid.{'xyz'[int(arg[3:])]};")
elif arg.startswith('lid'):
ins.append(f"mov.u32 {out}, %tid.{'xyz'[int(arg[3:])]};")
elif uop == UOps.ALU:
elif uop == Ops.ALU:
if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
ins.append(f"and.pred {out}, {', '.join(str(x) for x in vin)};")
else:
@@ -64,7 +64,7 @@ def specialize_to_ptx(lang, function_name):
ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};")
vin = vin[1:] + [reg]
ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};")
elif uop == UOps.LOAD:
elif uop == Ops.LOAD:
if arg.__class__ in (int, float):
ins.append(f"mov.{dtype_to_nvtype[out.dtype]} {out}, {'0f'+float_to_hex(arg) if dtypes.is_float(out.dtype) else int(arg)};")
elif arg[2] is not None and (arg[2] == dtypes.bool or arg[2] != out.dtype):
@@ -74,7 +74,7 @@ def specialize_to_ptx(lang, function_name):
render_cast(ins, reg, out)
else:
ins.append(f"ld.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} {out}, [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}];")
elif uop == UOps.STORE:
elif uop == Ops.STORE:
if ptx_needs_cast(dtypes.float if arg[2] is None else arg[2], vin[1].dtype) or arg[2] == dtypes.bool:
if arg[2] == dtypes.bool != vin[1].dtype:
prereg = lang.newreg((vin[1],'bool'), dtype=dtypes.bool)
@@ -85,11 +85,11 @@ def specialize_to_ptx(lang, function_name):
ins.append(f"st.{arg[1]}.{dtype_to_nvtype['bits16' if arg[2] == dtypes.float16 else dtypes.uint8 if arg[2] == dtypes.bool else dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {reg};")
else:
ins.append(f"st.{arg[1]}.{dtype_to_nvtype[dtypes.float if arg[2] is None else arg[2]]} [{vin[0]}{f'+{arg[0]}' if arg[0] is not None else ''}], {vin[1]};")
elif uop == UOps.CAST:
elif uop == Ops.CAST:
render_cast(ins, vin[0], out)
elif uop == UOps.LABEL:
elif uop == Ops.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
elif uop == Ops.COND_BRANCH:
ins.append(f"@{'!' if not arg[1] else ''}{vin[0]} bra {arg[0]};")
ins_prefix = [".version 7.8", ".target " + arch(), ".address_size 64",

View File

@@ -2,7 +2,7 @@ import yaml
from typing import Tuple, Set, Dict
from tinygrad import dtypes
from tinygrad.codegen.assembly import AssemblyCodegen, Register
from tinygrad.codegen.kernel import UOps
from tinygrad.codegen.kernel import Ops
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
@@ -61,7 +61,7 @@ class RDNACodegen(AssemblyCodegen):
def reg_out(x):
return rtor[x]
for uop, out, vin, arg in asm:
if uop == UOps.DEFINE_REGISTER:
if uop == Ops.DEFINE_REGISTER:
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
for i in range(arg[2]):
# TODO: Re-use gaps created by this to avoid wasting registers
@@ -86,7 +86,7 @@ class RDNACodegen(AssemblyCodegen):
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
else:
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
elif uop == UOps.SPECIAL:
elif uop == Ops.SPECIAL:
if arg.startswith('buf'):
i = int(arg[3:])
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
@@ -106,7 +106,7 @@ class RDNACodegen(AssemblyCodegen):
pend_regs.clear()
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
elif uop == UOps.CONST:
elif uop == Ops.CONST:
if arg == float('inf'): arg = "0x7f800000"
elif arg == float('-inf'): arg = "0xff800000"
if out.dtype == dtypes.float.vec(4):
@@ -114,7 +114,7 @@ class RDNACodegen(AssemblyCodegen):
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
elif uop == UOps.ALU:
elif uop == Ops.ALU:
if arg in [BinaryOps.CMPLT]:
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
else:
@@ -127,7 +127,7 @@ class RDNACodegen(AssemblyCodegen):
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
else:
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
elif uop == UOps.LOAD:
elif uop == Ops.LOAD:
if out.scalar:
# swap arg order
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
@@ -135,13 +135,13 @@ class RDNACodegen(AssemblyCodegen):
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
pend_regs.add(out)
for r in out.subregs(): pend_regs.add(r)
elif uop == UOps.STORE:
elif uop == Ops.STORE:
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
elif uop == UOps.LABEL:
elif uop == Ops.LABEL:
ins.append(f"{arg}:")
elif uop == UOps.COND_BRANCH:
elif uop == Ops.COND_BRANCH:
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
elif uop == UOps.CAST:
elif uop == Ops.CAST:
if vin[0].dtype == dtypes.bool:
if out.dtype == dtypes.float32:
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")

View File

@@ -2,7 +2,7 @@ from typing import Dict, List, Final, Callable, DefaultDict
from collections import defaultdict
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Op
from tinygrad.helpers import DType, PtrDType, dtypes, ImageDType, DEBUG, getenv
from tinygrad.codegen.kernel import UOp, UOps
from tinygrad.codegen.kernel import UOp, Ops
from triton.compiler import compile as triton_compile
import linecache
import math
@@ -75,32 +75,32 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop == UOps.LOOP:
if uop == Ops.LOOP:
kk(f"for {ssa(u, 'ridx')} in range({vin[0].arg}, {r[vin[1]]}):")
depth += 1
elif uop == UOps.END: depth -= 1
elif uop == UOps.ALU:
elif uop == Ops.END: depth -= 1
elif uop == Ops.ALU:
assert dtype is not None
val = code_for_op[args](*[r[x] for x in vin])
if child_count[u] <=1 or dtypes.is_int(dtype): r[u] = int_div(*[r[x] for x in vin]) if args == BinaryOps.DIV and dtypes.is_int(dtype) else val
else: kk(f"{ssa(u, 'alu')} = ({val})")
elif uop == UOps.LOAD:
elif uop == Ops.LOAD:
assert dtype is not None
if len(vin) == 2: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.load({r[vin[0]]} + { fill_dims_for_idx(r[vin[1]], dims)}, mask = {render_valid(valid)})', dtype)}")
else: kk(f"{ssa(u, 'val')} = {render_cast(f'tl.where({r[vin[2]]}, tl.load({r[vin[0]]}+{fill_dims_for_idx(r[vin[1]],dims)} , mask={render_valid(valid+[r[vin[2]]])}), 0.0)', dtype)}")
elif uop == UOps.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
elif uop == UOps.CONST: r[u] = define_scalar([], dtype, args)
elif uop == UOps.ASSIGN:
elif uop == Ops.DEFINE_ACC: kk(f"{ssa(u, 'acc')} = {define_scalar(local_size, dtype, args).replace('//', '/')}")
elif uop == Ops.CONST: r[u] = define_scalar([], dtype, args)
elif uop == Ops.ASSIGN:
kk(f"{r[vin[0]]} = {r[vin[1]].replace('//', '/')}")
r[u] = r[vin[0]]
elif uop == UOps.STORE:
elif uop == Ops.STORE:
assert not isinstance(dtype, ImageDType), "unimplemented: image store"
kk(f"{'if '+r[vin[3]]+': ' if len(vin)>3 else ''}tl.store({r[vin[0]]} + {r[vin[1]]}, {r[vin[2]].replace('//', '/')}, mask = {render_valid(valid)}) ")
elif uop == UOps.DEFINE_GLOBAL:
elif uop == Ops.DEFINE_GLOBAL:
bufs.append(args)
signatures.append("*" if isinstance(dtype, PtrDType) else "" + signature_dtypes[dtype])
r[u] = args
elif uop == UOps.SPECIAL:
elif uop == Ops.SPECIAL:
dims.append(args[1])
valid.append(f"{args[1]}<{get_max(args[2])}")
if args[1].startswith("g"): kk(f"{args[1]} = tl.program_id({args[0]}) # {args[2]}")
@@ -108,7 +108,7 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
kk(f"{args[1]} = tl.arange({0}, {next_power_of_2(args[2])})")
local_size.append(args[2])
r[u] = args[1]
elif uop == UOps.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
elif uop == Ops.CAST and dtype is not None: r[u] = render_cast(r[vin[0]], dtype, isinstance(args, tuple) and args[1])
else: raise NotImplementedError(f"unimplemented: {uop}")
prg = f"import triton\nimport triton.language as tl\ntl.core.TRITON_MAX_TENSOR_NUMEL = float('inf')\n@triton.jit\ndef {function_name}("+','.join(bufs)+"):\n"

View File

@@ -2,11 +2,12 @@
from typing import Tuple
from tinygrad import Variable
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.ops import UOp, UOps, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, ReduceOps, MetaOps
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
inf, nan = float('inf'), float('nan')
UOps = Ops
# kernel unpacker
from tinygrad.codegen.kernel import Kernel
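The only non-mechanical change in this file is the UOps = Ops alias added above. A hedged guess at why it is kept: AST strings serialized before the rename still spell UOps.*, and this helper eval()s them, so the alias lets old strings load unchanged. A sketch with a hypothetical pre-rename string (not from the repo):

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops

UOps = Ops                                                    # backward-compat alias, as in the hunk
old_ast_str = "UOp(UOps.CONST, dtypes.int, arg=1, src=())"    # hypothetical stored string
u = eval(old_ast_str)                                         # still resolves: UOps now points at Ops
assert u.op is Ops.CONST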

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from typing import List, Tuple, DefaultDict
from extra.optimization.helpers import load_worlds, ast_str_to_ast
from tinygrad.helpers import prod, tqdm
from tinygrad.ops import UOp, UOps
from tinygrad.ops import UOp, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import sym_infer, Node
@@ -136,7 +136,7 @@ def test_rebuild(st: ShapeTracker):
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
def test_rebuild_bufferop_st(ast:UOp):
if ast.op is UOps.SHAPETRACKER:
if ast.op is Ops.SHAPETRACKER:
test_rebuild(ast.arg)
for src in ast.src: test_rebuild_bufferop_st(src)

View File

@@ -2,7 +2,7 @@ from typing import List
from extra.models.resnet import ResNet50
from tinygrad import Tensor, Device, nn
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
from tinygrad.codegen.linearize import linearize_uop
@@ -27,7 +27,7 @@ if __name__ == "__main__":
sched = out.schedule()
if not SCHEDULE_ONLY:
asts = list({x.ast.key:x.ast for x in sched if x.ast.op is UOps.SINK}.values())
asts = list({x.ast.key:x.ast for x in sched if x.ast.op is Ops.SINK}.values())
if (restrict_kernel := getenv("RESTRICT_KERNEL", -1)) != -1: asts = asts[restrict_kernel:restrict_kernel+1]
kernels: List[Kernel] = []
with Timing(f"***** model opts({len(asts):2d}) in "):

View File

@@ -2,7 +2,7 @@
import unittest
from tinygrad import Device
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
from tinygrad.engine.search import Opt, OptOps
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -13,40 +13,40 @@ class TestOpenpilotValidhack(unittest.TestCase):
def test_valid_removal(self):
Device.DEFAULT = "GPU"
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((64, 1024, 4)), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 4096, 32, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
x5:=UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((128, 768, 4)), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 3, 1, 4, 4, 130, 4, 258), strides=(0, 0, 0, 0, 0, 4, 0, 1, 0, 3072, 0, 12), offset=-3084, mask=((0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 3), (0, 1), (0, 4), (0, 4), (1, 129), (0, 4), (1, 257)), contiguous=False), View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 2064, 2, 0, 0, 0, 0, 2146560, 536640, 135192, 259), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x19:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=(
x20:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.CONST, dtypes.float, arg=1.0, src=(
x20,)),
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.EXP2, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
x5,
UOp(UOps.CONST, dtypes.float, arg=1.4426950408889634, src=(
UOp(Ops.CONST, dtypes.float, arg=1.4426950408889634, src=(
x20,)),)),)),
x29:=UOp(UOps.CONST, dtypes.float, arg=-1.0, src=(
x29:=UOp(Ops.CONST, dtypes.float, arg=-1.0, src=(
x20,)),)),)),
x19,)),
x29,)),)),)),))
@@ -62,51 +62,51 @@ class TestOpenpilotValidhack(unittest.TestCase):
def test_const_idx(self):
Device.DEFAULT = "GPU"
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
x18:=UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((10, 128, 4)), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 512, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (0, 1), (0, 512)), contiguous=False),)), src=()),)),
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18:=UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=45568, mask=((0, 1), (2, 3), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=45568, mask=((0, 1), (2, 3), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=43008, mask=((0, 1), (3, 4), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=43008, mask=((0, 1), (3, 4), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=40448, mask=((0, 1), (4, 5), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=40448, mask=((0, 1), (4, 5), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=37888, mask=((0, 1), (5, 6), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=37888, mask=((0, 1), (5, 6), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=35328, mask=((0, 1), (6, 7), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=35328, mask=((0, 1), (6, 7), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=32768, mask=((0, 1), (7, 8), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=32768, mask=((0, 1), (7, 8), (0, 512)), contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
x18,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=30208, mask=((0, 1), (8, 9), (0, 512)), contiguous=False),)), src=()),)),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=30208, mask=((0, 1), (8, 9), (0, 512)), contiguous=False),)), src=()),)),)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)]
kernel = Kernel(ast)

View File

@@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
from tinygrad.ops import UnaryOps, UOp, UOps
from tinygrad.ops import UnaryOps, UOp, Ops
from test.helpers import is_dtype_supported
def on_linearizer_will_run(): pass
@@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
def _is_simple(lin: Kernel) -> bool:
if len(lin.ast.src) > 1: return False
ast:UOp = lin.ast.src[0]
if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is UOps.LOAD: return True
if ast.src[0].arg is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True
return False
if __name__ == "__main__":

View File

@@ -2,7 +2,7 @@ import sys, time, logging, difflib
from typing import Callable, Optional, Tuple
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOp, UOps, sint
from tinygrad.ops import UOp, Ops, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
@@ -69,7 +69,7 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
return UOp(Ops.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
st = time.perf_counter_ns()
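A hedged usage sketch of the ast_const helper changed above: it wraps a constant in an Ops.VALID gate so the test ASTs later in this commit can reference broadcast constants (the shape and dtype here are illustrative):

from tinygrad import dtypes
from tinygrad.ops import Ops
from test.helpers import ast_const

c = ast_const(dtypes.half, 0.0, shape=(64, 1, 512, 7, 7, 1, 1, 1))
# built as Ops.VALID(st).where(const, 0); the gate is the first source
assert c.src[0].op is Ops.VALID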
@@ -77,7 +77,7 @@ def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
return ret, (time.perf_counter_ns()-st)*1e-6
def eval_uop(uop:UOp):
g = UOp(UOps.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
rw = full_graph_rewrite(UOp.store(g, UOp.const(dtypes.int, 0), uop).sink(), PythonRenderer)
prog = PythonProgram("run", PythonCompiler().compile(PythonRenderer().render("run", linearize_uop(rw))))
buf = PythonAllocator().alloc(uop.dtype.itemsize)
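And a hedged sketch of the eval_uop helper from the second hunk: it evaluates a small UOp expression on the Python backend and reads back the result (the expected value is my assumption about what the helper returns):

from tinygrad import dtypes
from tinygrad.ops import UOp
from test.helpers import eval_uop

print(eval_uop(UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 3)))   # expected: 5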

View File

@@ -1,6 +1,6 @@
import unittest, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import CI
import numpy as np
@@ -9,7 +9,7 @@ from test.helpers import is_dtype_supported
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = create_schedule(t.lazydata.lbs)
asts = [s for s in schedule if s.ast.op is UOps.SINK]
asts = [s for s in schedule if s.ast.op is Ops.SINK]
assert len(asts) == desired_count
class TestUnaryOpsConstFolding(unittest.TestCase):

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.engine.schedule import create_schedule
@@ -15,9 +15,9 @@ class TestConvShapetracker(unittest.TestCase):
# first run to init the weights, they are scheduled.
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
# run it again to get the kernels
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is UOps.SINK]
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
for st in [x.st_arg for x in sched[0].ast.parents if x.op is UOps.LOAD]:
for st in [x.st_arg for x in sched[0].ast.parents if x.op is Ops.LOAD]:
assert len(st.views) == 1
def test_conv_2x2_backward_one_view(self):
@@ -26,7 +26,7 @@ class TestConvShapetracker(unittest.TestCase):
conv(X).mean().backward()
si = X.grad.schedule()[-1]
print(si)
ldb = [x for x in si.ast.parents if x.op is UOps.LOAD][0]
ldb = [x for x in si.ast.parents if x.op is Ops.LOAD][0]
st: ShapeTracker = ldb.st_arg.simplify()
# NOTE: st.real_size() is broken
print(si.inputs[0].size)

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.ops import UnaryOps, UOps
from tinygrad.ops import UnaryOps, Ops
from tinygrad.tensor import _to_np_dtype
from test.helpers import is_dtype_supported
import pytest
@@ -79,7 +79,7 @@ def universal_test_unary(a, dtype, op):
np.testing.assert_allclose(tensor_value, numpy_value, atol=1e-3, rtol=1e-2)
else: np.testing.assert_equal(tensor_value, numpy_value)
if op[0] != Tensor.reciprocal: # reciprocal is not supported in most backends
op = [x for x in ast.parents if x.op is UOps.ALU and x.arg in UnaryOps][0]
op = [x for x in ast.parents if x.op is Ops.ALU and x.arg in UnaryOps][0]
assert op.dtype == dtype
def universal_test_cast(a, in_dtype, dtype):

View File

@@ -2,7 +2,7 @@
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.engine.lazy import LazyBuffer, MetaOps
from tinygrad.engine.schedule import create_schedule
@@ -75,7 +75,7 @@ class TestReduceOp(unittest.TestCase):
a = a.sum()
sched = create_schedule([a.lazydata])
assert len(sched) == 1
self.assertIs(sched[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
@@ -83,7 +83,7 @@ class TestReduceOp(unittest.TestCase):
sched = create_schedule([a.lazydata])
assert len(sched) == 2
for s in sched:
self.assertIs(s.ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
@@ -91,7 +91,7 @@ class TestReduceOp(unittest.TestCase):
sched = create_schedule([a.lazydata])
assert len(sched) == 2
for s in sched:
self.assertIs(s.ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
class TestView(unittest.TestCase):
def test_all_masked_out(self):

File diff suppressed because it is too large

View File

@@ -5,7 +5,7 @@
import unittest
from test.helpers import ast_const
from tinygrad import Device, dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps
from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps
from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.engine.search import Opt, OptOps
@@ -14,26 +14,26 @@ from tinygrad.codegen.kernel import Kernel
class TestLinearizerDumb(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "METAL", "only tested on METAL")
def test_unmerged_ifs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.CAST, dtypes.half, arg=None, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
ast_const(dtypes.half, 0.9999950000374996, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.half, 0.0, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),))
opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0)]
k = Kernel(ast, opts=Device["METAL"].renderer)
k.required_optimizations()
@@ -43,88 +43,88 @@ class TestLinearizerDumb(unittest.TestCase):
Device[Device.DEFAULT].compiler.compile_cached(prg.src)
gate_count = len([x for x in prg.src.splitlines() if "if" in x])
assert gate_count == 1, f"must have only one gate {gate_count} != 1"
assert len([u for u in k.uops if u.op is UOps.IF]) == 1, "must have a single IF"
assert len([u for u in k.uops if u.op is Ops.IF]) == 1, "must have a single IF"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_max_simplify_and_cancel(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(UOps.CAST, dtypes.int, arg=None, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(Ops.CAST, dtypes.int, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
ast_const(dtypes.int, -1, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, 1000, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
assert prg.uops is not None and not any(uop.op is UOps.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
assert prg.uops is not None and not any(uop.op is Ops.ALU and uop.arg is BinaryOps.MAX for uop in prg.uops), "leftover MAX"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
def test_expander_new_srcs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is UOps.IF]
if_uops = [u for u in k.uops if u.op is Ops.IF]
self.assertIn(len(if_uops), {1,3})
conditions = if_uops[0].src[0].sparents
self.assertLessEqual(len(conditions), 9)
# this was a bug in embedding, someday we should fold this anyway
def test_llama_embedding(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CAST, dtypes.half, arg=None, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.CAST, dtypes.half, arg=None, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.CAST, dtypes.half, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
ast_const(dtypes.int, 1, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)),
ast_const(dtypes.int, -1, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
prg = k.to_program()
print(prg.src)
@@ -132,88 +132,88 @@ class TestLinearizerDumb(unittest.TestCase):
# from process replay https://github.com/tinygrad/tinygrad/actions/runs/10389229290/job/28766762085#step:18:6490
@unittest.expectedFailure
def test_unaligns_idxs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.CAST, dtypes.long, arg=None, src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CAST, dtypes.long, arg=None, src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is UOps.LOAD and x.src[0].arg == 3]
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 3]
assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"
@unittest.expectedFailure
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
def test_unrolled_float4_align(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(Ops.LOAD, dtypes.long, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
ast_const(dtypes.long, -1, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.float, 0.0, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is UOps.LOAD and x.src[0].arg == 2]
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2]
assert load_idxs[0] < load_idxs[1], f"first loaded idx {load_idxs[0].arg} then {load_idxs[1].arg}!"
@unittest.expectedFailure
@unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "need float4")
@unittest.skipIf(getenv("PTX"), "this is somehow correct in PTX")
def test_upcasted_stores_out_of_order(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
print(prg.src)
store_idxs = [x.src[1] for x in k.uops if x.op is UOps.STORE]
store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE]
for i in range(len(store_idxs) - 1):
first_bounds = store_idxs[i].vmin+store_idxs[i].vmax
next_bounds = store_idxs[i+1].vmin+store_idxs[i+1].vmax

File diff suppressed because it is too large


@@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
# stuff needed to unpack a kernel
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
from tinygrad.ops import UOp, Ops, BinaryOps, UnaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -24,143 +24,143 @@ def _test_overflow(ast, opts):
@unittest.skip("unneeded without launch bounds")
class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_1(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
x16:=ast_const(dtypes.float, 0.0, st_src=(
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
x23:=ast_const(dtypes.float, 1.0, st_src=(
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
x23,
ast_const(dtypes.float, 1e-05, st_src=(
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x16,)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
_test_overflow(ast, opts)
# From BEAM on hlb_cifar.py
def test_overflow_2(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
_test_overflow(ast, opts)
# from BEAM on default simple_conv.py (which is quite large):
def test_overflow_3(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
_test_overflow(ast, opts)
# from BEAM on BS=4 simple_conv.py:
def test_overflow_4(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts)
# from BEAM on BS=2 simple_conv.py:
def test_overflow_5(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
_test_overflow(ast, opts)
# from BEAM on BS=3 simple_conv.py:
def test_overflow_6(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
_test_overflow(ast, opts)
# from BEAM on BS=3 simple_conv.py: (alt)
def test_overflow_7(self):
ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
ast = UOp(Ops.SINK, None, arg=None, src=(
UOp(Ops.STORE, None, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts)
@@ -169,26 +169,26 @@ class TestLinearizerOverflow(unittest.TestCase):
class TestLinearizerOverflowAlt(unittest.TestCase):
def test_overflow_1(self):
BS = 2
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
prod = UOp(UOps.LOAD, dtypes.float, (g1, in_st_1)) * UOp(UOps.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(UOps.STORE, src=(g0, ot_st, UOp(UOps.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
ast = UOp(UOps.SINK, src=(store,))
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)]
_test_overflow(ast, opts)
def test_overflow_2(self):
BS = 2
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop()
prod = UOp(UOps.LOAD, dtypes.float, (g1, in_st_1)) * UOp(UOps.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(UOps.STORE, src=(g0, ot_st, UOp(UOps.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
ast = UOp(UOps.SINK, src=(store,))
prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2))
store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5)))))
ast = UOp(Ops.SINK, src=(store,))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)]
_test_overflow(ast, opts)


@@ -1,7 +1,7 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, UOps
from tinygrad.ops import MetaOps, ReduceOps, BinaryOps, Ops
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.schedule import create_schedule
@@ -617,24 +617,24 @@ class TestMultiTensor(unittest.TestCase):
t = t + 1
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.MUL
assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is UOps.LOAD
assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is Ops.LOAD
t = t + t.full_like(3)
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.op is Ops.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
def test_shard_memory(self):
devices = (d0, d1, d2, d3)


@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch
from tinygrad import Tensor, Device, TinyJit
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.helpers import CI, Context
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
@@ -517,7 +517,7 @@ class TestNN(unittest.TestCase):
[12, 19, 8, 1]])
result = layer(a)
schedule = create_schedule([result.lazydata])
self.assertEqual(3, len([item for item in schedule if item.ast.op is UOps.SINK]), "first run realizes arange, weight, and embedding")
self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
run_schedule(schedule)
b = Tensor([[1, 2, 3],
@@ -525,7 +525,7 @@ class TestNN(unittest.TestCase):
[7, 8, 9]])
result = layer(b)
schedule = create_schedule([result.lazydata])
self.assertEqual(1, len([item for item in schedule if item.ast.op is UOps.SINK]), "second run realizes embedding only")
self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
run_schedule(schedule)
def test_embedding_shape(self):


@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, prod
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps, UOp, UOps
from tinygrad.ops import BinaryOps, UOp, Ops
from tinygrad.renderer import Program
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.lazy import LazyBuffer
@@ -20,7 +20,7 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
def _recursive_add(uop:UOp) -> List[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
uops = dedup(flatten(_recursive_add(st) for st in stores))
outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is UOps.STORE]
initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
inbufs = [cast(LazyBuffer,x.lazydata).base.buffer for x in inputs]
src = Device[Device.DEFAULT].renderer.render("test", uops)
ei = CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops, local_size=local_size))
@@ -30,13 +30,13 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
class TestCStyleFailures(unittest.TestCase):
def test_inline_const_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
ld = UOp(Ops.LOAD, dtypes.int, (b, idx))
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
store = UOp.store(a, idx, alu)
sink = UOp(UOps.SINK, dtypes.void, (store,))
sink = UOp(Ops.SINK, dtypes.void, (store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
# CLANG doesn't use the max function
ret = _test_uop_result([Tensor([1])], uops)[0]
@@ -45,21 +45,21 @@ class TestCStyleFailures(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local and Device.DEFAULT == "PTX", "need local")
class TestPTXFailures(unittest.TestCase):
def test_gated_store_with_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])
def test_gated_store_with_if(self):
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
a = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
val = UOp.const(dtypes.int, 1)
if_uop = UOp(UOps.IF, dtypes.void, (gate_alu,))
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, val, if_uop))
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
gated_alu_store = UOp(Ops.STORE, dtypes.void, (a, lidx0, val, if_uop))
sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
np.testing.assert_equal(ret, [0, 1, 1, 1])


@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
@@ -29,7 +29,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
if to_prerealize:
for pre in to_prerealize: pre.schedule()
sched = create_schedule(outs)
if filter_sink: sched = [s for s in sched if s.ast.op is UOps.SINK]
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
# test the (sink) ops linearize
for s in sched:
if s.ast.op is not UOps.SINK: continue
if s.ast.op is not Ops.SINK: continue
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
@@ -58,7 +58,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
dtypes.default_float = old_default_float
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
run_schedule(s.copy())
cnt = len([si for si in s if si.ast.op is UOps.SINK])
cnt = len([si for si in s if si.ast.op is Ops.SINK])
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
if getenv("CHECK", 1):
import torch
@@ -191,7 +191,7 @@ class TestSchedule(unittest.TestCase):
r1 = (x - r0).sum(axis=0).div(2)
out = r0 + r1
schedule = check_schedule(out, 2)
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
def test_cache_reduce_multiple_children(self):
@@ -202,7 +202,7 @@ class TestSchedule(unittest.TestCase):
out0 = r0 + y
out1 = r1 + y
schedule = check_schedule([out0, out1], 4)
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
def test_fold_double_unary(self):
@@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
a = Tensor.empty(16, 16)
b = (a.sum(0) + a.max(1)) + 2
schedule = check_schedule(b, 2)
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
# multireduce spec
def test_multireduce_midreduce_nochase(self):
@@ -1117,7 +1117,7 @@ class TestSchedule(unittest.TestCase):
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
# schedule = check_schedule(b, 2)
schedule = check_schedule(b, 4)
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
@@ -1352,7 +1352,7 @@ class TestIndexing(unittest.TestCase):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
lst = [xt] if isinstance(xt, Tensor) else xt
s = Tensor.schedule(*lst)
kernels = [si for si in s if si.ast.op is UOps.SINK]
kernels = [si for si in s if si.ast.op is Ops.SINK]
for si in kernels: verify_ast(si.ast)
run_schedule(s)
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
@@ -1607,20 +1607,20 @@ class TestIndexing(unittest.TestCase):
self.assertLess(et, 1200)
def test_no_rewrite_elementwise(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
rsink = graph_rewrite(sink, view_right)
self.assertEqual(rsink.key, sink.key)
def test_simple_store_reshape(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
@@ -1628,21 +1628,21 @@ class TestIndexing(unittest.TestCase):
verify_ast(rsink)
def test_no_reshape_reduceop(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
def test_reshape_many(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
@@ -1656,11 +1656,11 @@ class TestIndexing(unittest.TestCase):
sizes = [10*(i+1) for i in range(SZ)]
tms: List[float] = []
for sz in sizes:
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, view_right)
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(rsink)
@@ -1676,20 +1676,20 @@ class TestIndexing(unittest.TestCase):
def test_swizzle_rewrite(self):
# graph rewrite
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# verify output
k = Kernel(sink)
@@ -1705,13 +1705,13 @@ class TestIndexing(unittest.TestCase):
a = Tensor.randint(4,).realize()
expected_out = a.numpy().sum(0)+1
# LazyBuffer to pre-rewrite AST
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
const = ast_const(dtypes.int, 1, ())
alu = swizzle_r+const
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
# graph rewrite
sink = graph_rewrite(sink, view_right)
# verify output
@@ -1728,13 +1728,13 @@ class TestIndexing(unittest.TestCase):
b = Tensor.randint(4,).realize()
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
# LazyBuffer to pre-rewrite AST
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
# graph rewrite
sink = graph_rewrite(sink, view_right)
# verify output
@@ -1745,51 +1745,51 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_equal(c.numpy(), expected_out)
def test_swizzle_rewrite_alt(self):
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
# there's an EXPAND pushing through the REDUCE_AXIS
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
# EXPAND is rewritten
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
# and pushed to the LOAD
new_load_st = unwrap([x for x in ret.parents if x.op is UOps.VIEW][0].st)
new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_permute_rewrite(self):
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),))
@track_rewrites()
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
ret = rewrite(sink)
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
assert len([x for x in ret.sparents if x.op is Ops.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
if __name__ == '__main__':
unittest.main(verbosity=2)
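
A minimal standalone sketch of the AST-construction pattern these schedule tests exercise, assuming the tinygrad internals imported at the top of this file (UOp, Ops, ShapeTracker, graph_rewrite, view_right) behave as shown above; illustrative only:

from tinygrad import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.schedule import view_right

# one output buffer (arg=0) and two input buffers
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
# load both inputs through a (32, 32) ShapeTracker view
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
# store the elementwise sum and wrap it in a SINK, the root of every kernel AST
st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1 + ld2))
sink = UOp(Ops.SINK, dtypes.void, (st,))
# a purely elementwise AST has no swizzled views to push, so view_right leaves it unchanged
assert graph_rewrite(sink, view_right).key == sink.key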

View File

@@ -3,7 +3,7 @@ import unittest
from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, UOps, BinaryOps
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
@@ -16,15 +16,15 @@ from tinygrad.shape.view import View
class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0]
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is UOps.LOAD}
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.parents if x.op is Ops.LOAD}
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
tm = time_linearizer(Kernel(si.ast), rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
assert tm > 0 and tm != float('inf')
def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is UOps.SINK][0]
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
assert len(rawbufs) == len(lin.membufs) == 2
assert all(r is not None for r in rawbufs)
@@ -34,7 +34,7 @@ class TestTimeLinearizer(unittest.TestCase):
def test_bufs_from_lin_alt(self):
a = Tensor.randn(4, 4).realize()
b = a+a[0]
si = [si for si in b.schedule() if si.ast.op is UOps.SINK][0]
si = [si for si in b.schedule() if si.ast.op is Ops.SINK][0]
rawbufs = bufs_from_lin(k:=Kernel(si.ast))
assert len(rawbufs) == len(k.membufs) == 2
assert all(r is not None for r in rawbufs)
@@ -46,12 +46,12 @@ class TestTimeLinearizer(unittest.TestCase):
Ensure that the kernel count is not incremented by time_linearizer when clearing l2
"""
# ast of Tensor.zeros(16).contiguous().realize()
ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
ast = UOp(Ops.SINK, src=(
UOp(Ops.STORE, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
ast_const(dtypes.float, 0.0, st_src=(
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
lin = Kernel(ast)
bufs = bufs_from_lin(lin)
@@ -103,37 +103,37 @@ class TestBEAM(unittest.TestCase):
def test_filter_global_buffer(self):
# taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
ast_const(dtypes.float, 1.4285714285714286, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
lin = Kernel(ast)
bufs = bufs_from_lin(lin)
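
A rough sketch of the search workflow the tests above cover (schedule a tiny tensor op, wrap its SINK ast in a Kernel, allocate buffers, time it), assuming the APIs imported in this file; illustrative only:

from tinygrad import Tensor
from tinygrad.ops import Ops
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import bufs_from_lin, time_linearizer

# keep only schedule items that are actual kernels (their ast is a SINK)
si = [i for i in create_schedule([Tensor([1, 2, 3, 4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
lin = Kernel(si.ast)          # the "linearizer" the search tooling operates on
rawbufs = bufs_from_lin(lin)  # scratch buffers matching the kernel's membufs
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
assert tm > 0 and tm != float('inf')  # a finite, positive wall-clock estimate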

View File

@@ -2,7 +2,7 @@ from typing import List
import unittest, time
from tinygrad import dtypes, Device
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, Ops, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher
from tinygrad.renderer import Renderer
from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
@@ -24,30 +24,30 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
c1 = UOp.const(dtypes.int, 1)
c2 = UOp.const(dtypes.int, 2)
st = time.perf_counter()
uops = [UOp(UOps.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
uops = [UOp(Ops.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)]
et = time.perf_counter() - st
print(f"created {len(uops)} uops in {et*1000:.2f} ms")
def test_expand_rewrite(self):
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
sink = UOp(Ops.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
offset=0, mask=None, contiguous=False),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 10)), src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
UOp(Ops.LOAD, dtypes.half, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
lower_sink = rewrite_shapetracker_with_index(sink, Device[Device.DEFAULT].renderer)
@@ -82,7 +82,7 @@ class TestGraphRewriteConst(unittest.TestCase):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
v2 = UOp.const(dtypes.int.vec(3), (5,6,7))
ret = graph_rewrite(v1+v2, sym)
self.assertEqual(ret.op, UOps.VCONST)
self.assertEqual(ret.op, Ops.VCONST)
self.assertEqual(ret.dtype, dtypes.int.vec(3))
self.assertEqual(ret.arg, (5,7,9))
@@ -90,31 +90,31 @@ class TestGraphRewriteConst(unittest.TestCase):
v1 = UOp.const(dtypes.int.vec(3), (0,1,2))
v2 = UOp.const(dtypes.int.vec(3), (2,1,0))
ret = graph_rewrite(v1+v2, sym)
self.assertEqual(ret.op, UOps.CONST)
self.assertEqual(ret.op, Ops.CONST)
self.assertEqual(ret.dtype, dtypes.int.vec(3))
self.assertEqual(ret.arg, 2)
class TestGraphRewrite(unittest.TestCase):
def test_dedup(self):
v1 = UOp(UOps.DEFINE_VAR, dtypes.float)
v2 = UOp(UOps.DEFINE_VAR, dtypes.float)
v1 = UOp(Ops.DEFINE_VAR, dtypes.float)
v2 = UOp(Ops.DEFINE_VAR, dtypes.float)
nout = graph_rewrite(v1+v2, PatternMatcher([]))
self.assertIs(nout.src[0], nout.src[1])
# NOTE: this shows why we can't have a UOp in arg
@unittest.expectedFailure
def test_no_dedup_args(self):
a1 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
a2 = UOp(UOps.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
a1 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a1", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
a2 = UOp(Ops.DEFINE_VAR, dtypes.int, (), ("a2", UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 11)))
sink = a1.sink(a2)
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is UOps.DEFINE_VAR]
define_vars = [x for x in graph_rewrite(sink, PatternMatcher([])).sparents if x.op is Ops.DEFINE_VAR]
self.assertEqual(len(define_vars), 1)
def test_simple(self):
c1 = UOp.const(dtypes.float, 1.0)
c2 = UOp.const(dtypes.float, 2.0)
nout = graph_rewrite(c1+c2, simple_pm)
self.assertEqual(nout.op, UOps.CONST)
self.assertEqual(nout.op, Ops.CONST)
self.assertEqual(nout.arg, 3.0)
def test_depth_2_late(self):
@@ -122,7 +122,7 @@ class TestGraphRewrite(unittest.TestCase):
c2 = UOp.const(dtypes.float, 2.0)
c3 = UOp.const(dtypes.float, 3.0)
nout = graph_rewrite(c1*c2*(c3+c3), simple_pm)
self.assertEqual(nout.op, UOps.CONST)
self.assertEqual(nout.op, Ops.CONST)
self.assertEqual(nout.arg, 12.0)
def test_double(self):
@@ -130,7 +130,7 @@ class TestGraphRewrite(unittest.TestCase):
c2 = UOp.const(dtypes.float, 2.0)
c3 = UOp.const(dtypes.float, 3.0)
nout = graph_rewrite(c1+c2+c3, simple_pm)
self.assertEqual(nout.op, UOps.CONST)
self.assertEqual(nout.op, Ops.CONST)
self.assertEqual(nout.arg, 6.0)
def test_triple(self):
@@ -139,7 +139,7 @@ class TestGraphRewrite(unittest.TestCase):
c3 = UOp.const(dtypes.float, 3.0)
c4 = UOp.const(dtypes.float, 4.0)
nout = graph_rewrite(c1+c2+c3+c4, simple_pm)
self.assertEqual(nout.op, UOps.CONST)
self.assertEqual(nout.op, Ops.CONST)
self.assertEqual(nout.arg, 10.0)
def test_diamond(self):
@@ -147,23 +147,23 @@ class TestGraphRewrite(unittest.TestCase):
c2 = UOp.const(dtypes.float, 2.0)
c3 = UOp.const(dtypes.float, 3.0)
nout = graph_rewrite((c1+c2)+(c1+c3), simple_pm)
self.assertEqual(nout.op, UOps.CONST)
self.assertEqual(nout.op, Ops.CONST)
self.assertEqual(nout.arg, 7.0)
def test_magic_4(self):
c1 = UOp.const(dtypes.int, 4.0)
nout = graph_rewrite(c1, simple_pm)
self.assertEqual(nout.op, UOps.CONST)
self.assertEqual(nout.op, Ops.CONST)
self.assertEqual(nout.arg, 3.0)
def test_depth_2_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.float)
v = UOp(Ops.DEFINE_VAR, dtypes.float)
c1 = UOp.const(dtypes.float, 1.0)
c2 = UOp.const(dtypes.float, 2.0)
nout = graph_rewrite(v+c1+c2, simple_pm)
self.assertEqual(nout.op, UOps.ALU)
self.assertEqual(nout.src[0].op, UOps.DEFINE_VAR)
self.assertEqual(nout.src[1].op, UOps.CONST)
self.assertEqual(nout.op, Ops.ALU)
self.assertEqual(nout.src[0].op, Ops.DEFINE_VAR)
self.assertEqual(nout.src[1].op, Ops.CONST)
self.assertEqual(nout.src[1].arg, 3.0)
def test_commutative_work(self):
@@ -182,77 +182,77 @@ class TestGraphRewrite(unittest.TestCase):
b = UOp.variable('b', 0, 1)
c = UOp.variable('c', 0, 1)
d = UOp.variable('d', 0, 1)
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
outs = [2+a, 2+a+d+3+b+c+4, UOp(Ops.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
for out in outs:
sink = graph_rewrite(out, sym)
print(sink.render())
self.assertEqual(sink.op, UOps.ALU)
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
self.assertEqual(sink.op, Ops.ALU)
self.assertEqual(sink.src[1].op, Ops.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is Ops.CONST]), 1)
class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
out = UOp(Ops.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self):
v = UOp.variable('tmp', 0, 1)
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
c0 = UOp(Ops.CONST, dtypes.int, arg=0)
vc = UOp(Ops.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
out = UOp(Ops.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 1.0)
def test_where_const_fold(self):
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
out = UOp(Ops.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 2.0)
def test_const_cast(self):
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
out = UOp(UOps.CAST, dtypes.int, (bf,))
bf = UOp(Ops.CONST, dtypes.bool, arg=False)
out = UOp(Ops.CAST, dtypes.int, (bf,))
uops = to_uops_list([out])
self.assertEqual(len(uops), 1)
out = uops[-1]
self.assertEqual(out.op, UOps.CONST)
self.assertEqual(out.op, Ops.CONST)
self.assertEqual(out.arg, 0)
@unittest.skip("this test isn't valid uops")
def test_noop_vectorize_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
x = UOp(UOps.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
ld = UOp(Ops.LOAD, dtypes.float.vec(2), (d0, idx))
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(2), (ld,))
x = UOp(Ops.GEP, dtypes.float, (vec, ), arg=0)
alu = UOp(Ops.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
self.assertEqual(len([x for x in uops if x.op is Ops.VECTORIZE]), 0)
def test_gep_vec_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
d2 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
d2 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
def _test_vec(geps, count=4):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
out = UOp(UOps.STORE, dtypes.void, (d0, idx, vec))
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(count), geps)
out = UOp(Ops.STORE, dtypes.void, (d0, idx, vec))
uops = to_uops_list([out])
if DEBUG >= 4:
from tinygrad import Device
@@ -260,53 +260,53 @@ class TestUOpGraph(unittest.TestCase):
return uops[-1].src[-1]
# possible
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in range(4))
self.assertIs(_test_vec(xyzw).op, UOps.LOAD)
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
xyzw = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in range(4))
self.assertIs(_test_vec(xyzw).op, Ops.LOAD)
# unaligned
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
wzyx = tuple(UOp(UOps.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
self.assertIs(_test_vec(wzyx).op, UOps.VECTORIZE)
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
wzyx = tuple(UOp(Ops.GEP, dtypes.float, (val,), (i,)) for i in reversed(range(4)))
self.assertIs(_test_vec(wzyx).op, Ops.VECTORIZE)
# different_size
val = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy+xy).op, UOps.VECTORIZE)
val = UOp(UOps.LOAD, dtypes.float.vec(4), (d1, idx))
xy = tuple(UOp(UOps.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy, count=2).op, UOps.VECTORIZE)
val = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx))
xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy+xy).op, Ops.VECTORIZE)
val = UOp(Ops.LOAD, dtypes.float.vec(4), (d1, idx))
xy = tuple(UOp(Ops.GEP, dtypes.float, (val, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy, count=2).op, Ops.VECTORIZE)
# different vals
val1 = UOp(UOps.LOAD, dtypes.float.vec(2), (d1, idx))
val2 = UOp(UOps.LOAD, dtypes.float.vec(2), (d2, idx))
xy1 = tuple(UOp(UOps.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
val1 = UOp(Ops.LOAD, dtypes.float.vec(2), (d1, idx))
val2 = UOp(Ops.LOAD, dtypes.float.vec(2), (d2, idx))
xy1 = tuple(UOp(Ops.GEP, dtypes.float, (val1, ), (i,)) for i in range(2))
xy2 = tuple(UOp(Ops.GEP, dtypes.float, (val2, ), (i,)) for i in range(2))
self.assertIs(_test_vec(xy1+xy2).op, Ops.VECTORIZE)
def test_gep_vec_const_fold(self):
for vec_size in [2, 4, 8]:
consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(UOps.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
vec = UOp(Ops.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
uops = to_uops_list([UOp(Ops.GEP, dtypes.float, (vec,), (i,)) for i in range(vec_size)])
for uop, const in zip(uops, consts):
self.assertEqual(uop, const)
def test_wmma_vectorize_fold(self):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[0], acc)
self.assertEqual(len(uops), 1)
@@ -314,109 +314,109 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("wmma is wrong here, it needs an arg")
def test_wmma_vectorize_no_fold(self):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
tuple(UOp(Ops.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
var = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(Ops.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
acc = UOp(Ops.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(Ops.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
self.assertEqual(uops[-1], wmma)
def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
ld = UOp(Ops.LOAD, dtypes.int, (d1, idx))
alu = ld.lt(1).cast(dtypes.bool)
out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 0)
def test_double_cast_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
d0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
ld = UOp(Ops.LOAD, dtypes.int, (d1, idx))
alu = ld.cast(dtypes.float).cast(dtypes.float)
out = UOp(UOps.STORE, dtypes.void, (d0, idx, alu))
out = UOp(Ops.STORE, dtypes.void, (d0, idx, alu))
uops = to_uops_list([out])
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
self.assertEqual(len([x for x in uops if x.op is Ops.CAST]), 1)
def test_depth_2_const_fold(self):
v = UOp.variable("tmp", 0, 1)
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
c2 = UOp(Ops.CONST, dtypes.int, arg=2)
c4 = UOp(Ops.CONST, dtypes.int, arg=4)
vc = UOp(Ops.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(Ops.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
uops = to_uops_list([out])
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.op, Ops.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
self.assertEqual(out.src[1].op, UOps.CONST)
self.assertEqual(out.src[1].op, Ops.CONST)
self.assertEqual(out.src[1].arg, 6)
def test_fold_gated_load(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False)))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True)))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0, UOp.load(glbl2.index(idx), dtype=dtypes.int))
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
smem = UOp(UOps.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
ld0 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(UOps.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
uops = to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
smem = UOp(Ops.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(Ops.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st, ))
ld0 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+1, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, False), barrier))
ld1 = UOp(Ops.LOAD, dtypes.int, (smem, lidx+2, UOp.const(dtypes.int, 0), UOp.const(dtypes.bool, True), barrier))
uops = to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, lidx, ld1+ld0))])
ld0 = uops[-1].src[-1]
# the gate and invalid value are deleted from ld1
self.assertEqual(ld0.src[0], smem.index(lidx+2))
def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42)
st0 = UOp(UOps.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
st1 = UOp(UOps.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
st0 = UOp(Ops.STORE, dtypes.void, (glbl, idx0, val, UOp.const(dtypes.bool, False)))
st1 = UOp(Ops.STORE, dtypes.void, (glbl, idx1, val, UOp.const(dtypes.bool, True)))
uops = to_uops_list([st0, st1])
# only the second store happens
self.assertEqual(len(uops), 5)
@@ -424,23 +424,23 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx = UOp.const(dtypes.int, 0)
bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
with self.assertRaises(AssertionError): to_uops_list([UOp(Ops.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
def test_switched_range_order(self):
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
c0 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 2)
cf = UOp.const(dtypes.float, 0.0)
r1 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 0, False))
r2 = UOp(UOps.RANGE, dtypes.int, (c0, c2), (1, 1, False))
alu = UOp(UOps.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
store = UOp(UOps.STORE, dtypes.void, (glbl, alu, cf))
r1 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 0, False))
r2 = UOp(Ops.RANGE, dtypes.int, (c0, c2), (1, 1, False))
alu = UOp(Ops.ALU, dtypes.int, (r2, r1), BinaryOps.MUL)
store = UOp(Ops.STORE, dtypes.void, (glbl, alu, cf))
uops = to_uops_list([store])
ranges = [x for x in uops if x.op is UOps.RANGE]
endranges = [x for x in uops if x.op is UOps.ENDRANGE]
ranges = [x for x in uops if x.op is Ops.RANGE]
endranges = [x for x in uops if x.op is Ops.ENDRANGE]
# ranges are closed in the right order
self.assertEqual(endranges[-1].src[0], ranges[0])
@@ -449,90 +449,90 @@ def float4_rewrite(sink): return full_graph_rewrite(sink, Renderer())
class TestExpander(unittest.TestCase):
def test_expand_add_broadcast(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+3)
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 4
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 4
self.assertTupleEqual(sink.src[0].arg, (3,4,5,6))
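# (Summary inferred from the assertions in this class, not from the source tree:
#  an EXPAND carries one value per point of its ((axis, size), ...) arg inside a
#  vector src, scalar ALU ops broadcast across those lanes, and CONTRACT gathers
#  the lanes of the named axes back into a vector dtype; contracting only some
#  axes leaves an EXPAND over the remaining ones.)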
def test_contract_simple(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
self.assertEqual(sink.op, UOps.VCONST)
self.assertEqual(sink.op, Ops.VCONST)
self.assertTupleEqual(sink.arg, (0,1,2,3))
def test_contract_axis_1(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((1,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is UOps.VCONST
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((2,4),)
assert sink.src[0].op is Ops.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,4,8,12))
self.assertTupleEqual(sink.src[0].arg[12:], (3,7,11,15))
def test_contract_axis_2(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,4),(2,4)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2,4),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is UOps.VCONST
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16 and sink.arg == ((1,4),)
assert sink.src[0].op is Ops.VCONST
self.assertTupleEqual(sink.src[0].arg[0:4], (0,1,2,3))
self.assertTupleEqual(sink.src[0].arg[12:], (12,13,14,15))
def test_contract_axis_2_big(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (3, 2), (4, 2))
self.assertTupleEqual(sink.src[0].arg[0:2], (0,4))
self.assertTupleEqual(sink.src[0].arg[12:14], (10,14))
def test_contract_multi_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(16), tuple(x for x in range(16))),), ((1,2),(2,2),(3,2),(4,2)))
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((3, 2), (2, 2))))
assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 4, 2, 6))
sink = expander_rewrite(UOp(UOps.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is UOps.EXPAND and sink.arg == ((1, 2), (4, 2))
sink = expander_rewrite(UOp(Ops.CONTRACT, dtypes.int.vec(4), (e1,), ((2, 2), (3, 2))))
assert sink.op is Ops.EXPAND and sink.arg == ((1, 2), (4, 2))
self.assertTupleEqual(sink.src[0].arg[0:4], (0, 2, 4, 6))
def test_contract_mid(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(8), tuple(x for x in range(8))),), ((1,2),(2,2),(3,2)))
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.EXPAND and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is UOps.VCONST and len(sink.src[0].arg) == 8
assert sink.op is Ops.EXPAND and sink.arg == ((1,2),(3,2))
assert sink.src[0].op is Ops.VCONST and len(sink.src[0].arg) == 8
self.assertTupleEqual(sink.src[0].arg, (0,2,1,3,4,6,5,7))
def test_contract_no_expand(self):
e1 = UOp(UOps.DEFINE_VAR, dtypes.int)
con = UOp(UOps.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
e1 = UOp(Ops.DEFINE_VAR, dtypes.int)
con = UOp(Ops.CONTRACT, dtypes.int.vec(2), (e1,), ((2,2),))
sink = expander_rewrite(con)
assert sink.op is UOps.VECTORIZE and len(sink.src) == 2
assert sink.op is Ops.VECTORIZE and len(sink.src) == 2
assert sink.src[0] == sink.src[1]
def test_contract_half_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(UOps.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
con = UOp(Ops.CONTRACT, dtypes.int.vec(8), (e1,), ((1,4), (2,2)))
sink = expander_rewrite(con)
assert sink.op is UOps.VCONST and len(sink.arg) == 8
assert sink.op is Ops.VCONST and len(sink.arg) == 8
assert sink.arg[0] == sink.arg[1]
assert sink.arg[0] != sink.arg[2]
assert sink.arg[6] == sink.arg[7]
def test_expand_same_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
sink = expander_rewrite(e1+e2)
self.assertEqual(sink.op, UOps.EXPAND)
self.assertEqual(sink.src[0].op, UOps.VCONST)
self.assertEqual(sink.op, Ops.EXPAND)
self.assertEqual(sink.src[0].op, Ops.VCONST)
self.assertTupleEqual(sink.src[0].arg, (0,5,10,15))
def test_expand_different_axis(self, flip=False):
e1 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
e1 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(4*x for x in range(4))),), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(4), tuple(x for x in range(4))),), ((2,4),))
sink = expander_rewrite((e2+e1) if flip else (e1+e2))
assert sink.op is UOps.EXPAND and len(sink.src[0].arg) == 16
assert sink.op is Ops.EXPAND and len(sink.src[0].arg) == 16
assert sink.arg == ((1, 4), (2, 4))
self.assertTupleEqual(sink.src[0].arg, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
@@ -540,47 +540,47 @@ class TestExpander(unittest.TestCase):
@unittest.skip("no longer supported")
def test_reduce_known_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(UOps.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD)
sink = expander_rewrite(sink)
assert sink.op is UOps.CONST
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*(0+1+2+3))
@unittest.skip("no longer supported")
def test_reduce_const(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(UOps.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD)
sink = expander_rewrite(sink)
assert sink.op is UOps.CONST
assert sink.op is Ops.CONST
self.assertEqual(sink.arg, 3*4)
@unittest.skip("no longer supported")
def test_double_expand(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((1,2),))
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((2,4),))
e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((1,2),))
sink = expander_rewrite(e)
assert sink.op is UOps.EXPAND and len(sink.src) == 8
assert sink.op is Ops.EXPAND and len(sink.src) == 8
assert sink.arg == ((1, 2), (2, 4))
self.assertListEqual([x.arg for x in sink.src], [0,1,2,3,4,5,6,7])
@unittest.skip("no longer supported")
def test_double_expand_reverse(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,4),))
e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((2,2),))
sink = expander_rewrite(e)
assert sink.op is UOps.EXPAND and len(sink.src) == 8
assert sink.op is Ops.EXPAND and len(sink.src) == 8
assert sink.arg == ((1, 4), (2, 2))
self.assertListEqual([x.arg for x in sink.src], [0, 4, 1, 5, 2, 6, 3, 7])
@unittest.skip("no longer supported")
def test_double_expand_middle(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
e = UOp(UOps.EXPAND, dtypes.int, (e1, e2), ((2,2),))
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,2),(3,2)))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, 4+x) for x in range(4)), ((1,2),(3,2)))
e = UOp(Ops.EXPAND, dtypes.int, (e1, e2), ((2,2),))
sink = expander_rewrite(e)
assert sink.op is UOps.EXPAND and len(sink.src) == 8
assert sink.op is Ops.EXPAND and len(sink.src) == 8
assert sink.arg == ((1, 2), (2, 2), (3, 2))
self.assertListEqual([x.arg for x in sink.src], [0, 1, 4, 5, 2, 3, 6, 7])
@@ -588,106 +588,106 @@ class TestExpander(unittest.TestCase):
@unittest.expectedFailure
@unittest.skip
def test_reduce_different_axis(self):
e1 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(UOps.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
sink = UOp(UOps.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),))
e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),))
sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD)
sink = expander_rewrite(sink)
print(sink)
class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
def test_two_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 2
def test_simple_load_fold_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp(Ops.DEFINE_VAR, dtypes.bool)
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is UOps.LOAD][0]
self.assertEqual(single_load.src[1].op, UOps.VECTORIZE)
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 1
single_load = [x for x in sink.sparents if x.op is Ops.LOAD][0]
self.assertEqual(single_load.src[1].op, Ops.VECTORIZE)
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp.variable("g1", False, True, dtypes.bool)
gate2 = UOp.variable("g2", False, True, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
load = [UOp(Ops.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(Ops.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink.sink())
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
assert len([x for x in sink.sparents if x.op is Ops.LOAD]) == 3
def test_simple_store_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0))) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
def test_simple_store_fold_gate(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp.variable("g1", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, 0), gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
one_store = [x for x in sink.sparents if x.op is UOps.STORE][0]
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 1
one_store = [x for x in sink.sparents if x.op is Ops.STORE][0]
assert len(one_store.src) == 3
assert str(one_store.src[2]) == str(gate) # huh, why do i need str here?
def test_simple_store_dont_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp.variable("g1", False, True, dtypes.bool)
gate2 = UOp.variable("g2", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
load = [UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 3
assert len([x for x in sink.sparents if x.op is Ops.STORE]) == 3
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
idx = UOp.const(dtypes.int, 0)
st = UOp(UOps.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st,))
lbuf = UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
store = UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
sink = UOp(UOps.SINK, dtypes.void, (store,))
st = UOp(Ops.STORE, dtypes.void, (sbuf, idx, UOp.const(dtypes.float, 42)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
lbuf = UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, 0), barrier))
store = UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, 0), lbuf, gate))
sink = UOp(Ops.SINK, dtypes.void, (store,))
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
if_uops = [u for u in sink.parents if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
self.assertEqual(len(st.src), 2)
def test_expand_ifs_one_gate(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid&(lidx.ne(2))
st = UOp(UOps.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st,))
lbufs = [UOp(UOps.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
stores = [UOp(UOps.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
st = UOp(Ops.STORE, dtypes.void, (sbuf, lidx, UOp.const(dtypes.float, 42)))
barrier = UOp(Ops.BARRIER, dtypes.void, (st,))
lbufs = [UOp(Ops.LOAD, dtypes.float, (sbuf, UOp.const(dtypes.int, i), barrier)) for i in range(4)]
stores = [UOp(Ops.STORE, dtypes.void, (gbuf, UOp.const(dtypes.int, i), lbufs[i], gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
if_uops = [u for u in sink.parents if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:
@@ -696,14 +696,14 @@ class TestIFUOps(unittest.TestCase):
# this will be fixed with the merge gated stores bounty
@unittest.expectedFailure
def test_expand_ifs_dumb(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
valid = UOp(Ops.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(Ops.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
stores = [UOp(UOps.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(stores))
stores = [UOp(Ops.STORE, dtypes.void, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(Ops.SINK, dtypes.void, tuple(stores))
sink = full_graph_rewrite(sink)
if_uops = [u for u in sink.parents if u.op is UOps.IF]
if_uops = [u for u in sink.parents if u.op is Ops.IF]
self.assertEqual(len(if_uops), 1)
self.assertEqual(if_uops[0].src[0], gate)
for st in sink.src:

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
@@ -23,18 +23,18 @@ def _uops_to_prg(uops_list):
return CompiledRunner(Program("test", src, Device.DEFAULT, uops=uops,
global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(src), arg))
return uops[-1]
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i], uop(uops, Ops.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, Ops.ALU, output_dtype, loads, op)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=_to_np_dtype(dtype)).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg([out])
@@ -46,10 +46,10 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, Ops.ALU, output_dtype, loads, op)
out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg([out])
prg.exec([buf])
@@ -59,9 +59,9 @@ def _test_single_value_const(vals, op, dts):
def _test_uops_result(output_dtype, uops, res):
# uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
# res = output_fn(uops)
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
out = uop(uops, Ops.STORE, dtypes.void, (buf_store, uop(uops, Ops.CONST, dtypes.int32, (), 0), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
prg = _uops_to_prg([out])
prg.exec([buf])
@@ -244,63 +244,63 @@ class TestConstantFolding(unittest.TestCase):
si = create_schedule([t.lazydata])
assert len(si) == 1
ji = lower_schedule_item(si[-1])
assert any(uop.op is UOps.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
class TestGatedStoreRewrite(unittest.TestCase):
@unittest.expectedFailure
def test_tiny_gate_store(self):
gmem = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
gmem = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
store = UOp(UOps.STORE, dtypes.void, (gmem, idx, val, gate))
store = UOp(Ops.STORE, dtypes.void, (gmem, idx, val, gate))
uops = to_uops_list([store])
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
if_uop = next(u for u in uops if u.op is Ops.IF)
endif = next(u for u in uops if u.op is Ops.ENDIF)
assert endif.src[0] is if_uop
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
self.assertEqual(len(gated_uops), 1)
self.assertIs(gated_uops[-1].op, UOps.STORE)
self.assertIs(gated_uops[-1].op, Ops.STORE)
@unittest.expectedFailure
def test_gate_some_stores(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val)]
uops = linearize_uop(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
if_uop = next(u for u in uops if u.op is UOps.IF)
endif = next(u for u in uops if u.op is UOps.ENDIF)
if_uop = next(u for u in uops if u.op is Ops.IF)
endif = next(u for u in uops if u.op is Ops.ENDIF)
assert endif.src[0] is if_uop
gated_uops = tuple(uops.uops[uops.uops.index(if_uop)+1:uops.uops.index(endif)])
self.assertEqual(len(gated_uops), 1)
self.assertIs(gated_uops[-1].op, UOps.STORE)
self.assertIs(gated_uops[-1].op, Ops.STORE)
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.expectedFailure
def test_merge_ifs_alt(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
gmem0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0)
gate = gidx0.lt(UOp.const(dtypes.int, 1))
stores = [UOp.store(gmem0, idx, val, gate), UOp.store(gmem1, idx, val, gate)]
uops = linearize_uop(stores)
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
ifs = [u for u in uops if u.op is UOps.IF]
endifs = [u for u in uops if u.op is UOps.ENDIF]
ifs = [u for u in uops if u.op is Ops.IF]
endifs = [u for u in uops if u.op is Ops.ENDIF]
self.assertEqual(len(ifs), 1)
self.assertEqual(len(endifs), 1)
gated_uops = tuple(uops.uops[uops.uops.index(ifs[0])+1:uops.uops.index(endifs[0])])
self.assertEqual(len(gated_uops), 2)
for x in gated_uops: self.assertIs(x.op, UOps.STORE)
for x in gated_uops: self.assertIs(x.op, Ops.STORE)
class TestLocalAccess(unittest.TestCase):
# NOTE: this is failing on METAL CI, no idea why. Works locally.
@@ -308,44 +308,44 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_basic(self):
uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
st = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), uop(uops, Ops.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st,))
sres = uop(uops, Ops.LOAD, dtypes.float32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 0), barr))
self.assertEqual(_test_uops_result(dtypes.float32, uops, sres), 42)
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_indirect(self):
uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
ofs = uop(uops, UOps.LOAD, dtypes.int32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), barr))
sres = uop(uops, UOps.LOAD, dtypes.int32, (smem, ofs))
smem = uop(uops, Ops.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
st1 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), uop(uops, Ops.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, Ops.STORE, dtypes.void, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 2), uop(uops, Ops.CONST, dtypes.int32, (), 42)))
barr = uop(uops, Ops.BARRIER, dtypes.void, (st1,st2))
ofs = uop(uops, Ops.LOAD, dtypes.int32, (smem, uop(uops, Ops.CONST, dtypes.int32, (), 1), barr))
sres = uop(uops, Ops.LOAD, dtypes.int32, (smem, ofs))
self.assertEqual(_test_uops_result(dtypes.int32, uops, sres), 42)
@unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
class TestAssembly(unittest.TestCase):
def test_bitshift_left(self):
g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1))
a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHL)
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
def test_bitshift_right(self):
g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
g1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(Ops.CONST, dtypes.int, (), 2)
c2 = UOp(Ops.CONST, dtypes.int, (), 3)
l1 = UOp(Ops.LOAD, dtypes.int, (g1, c1))
a1 = UOp(Ops.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
a2 = UOp(Ops.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
uops = to_uops_list([a1,a2], opts=Device[Device.DEFAULT].renderer)
Device[Device.DEFAULT].renderer.render("test", uops)
self.assertEqual(uops[-1].arg, BinaryOps.SHR)
@@ -354,38 +354,38 @@ class TestAssembly(unittest.TestCase):
class TestUOpMethod(unittest.TestCase):
@unittest.skip("uops lt no longer ordered")
def test_compare_alu_same_src_different_arg(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0)
b = UOp(UOps.CONST, dtypes.float, (), 3.0)
a = UOp(Ops.CONST, dtypes.float, (), 2.0)
b = UOp(Ops.CONST, dtypes.float, (), 3.0)
add = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.ADD)
mul = UOp(UOps.ALU, dtypes.float, (a, b), BinaryOps.MUL)
add = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.ADD)
mul = UOp(Ops.ALU, dtypes.float, (a, b), BinaryOps.MUL)
assert (add < mul) or (mul < add), "add and mul with same src should have an order"
def test_uop_variables(self):
a = UOp.variable("a", 1, 10)
uop_var = UOp.const(dtypes.int, a)
st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
st_var = UOp(Ops.LOAD, dtypes.float, (UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
ShapeTracker.from_shape((2, a)).to_uop()))
ast_vars = (st_var+uop_var).variables()
self.assertEqual(len(ast_vars), 1)
self.assertEqual(ast_vars[0], a)
def test_const_factor(self):
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 8))
self.assertEqual(UOp(UOps.CONST, dtypes.int, (), 17).const_factor(), 17)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 8))
self.assertEqual(UOp(Ops.CONST, dtypes.int, (), 17).const_factor(), 17)
self.assertEqual(gidx0.const_factor(), 1)
self.assertEqual((gidx0*3).const_factor(), 3)
self.assertEqual((gidx0*3+6).const_factor(), 3)
self.assertEqual((gidx0*3+1).const_factor(), 1)
def test_replace(self):
x = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
x = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
self.assertIs(x.replace(arg=None).arg, None)
with self.assertRaises(AssertionError): x.replace(field="a")
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
for _ in range(20): a = a + a
assert len(str(a)) < 10_000, "exponential string growth"
assert str(eval(str(a))) == str(a)
@@ -394,11 +394,11 @@ class TestUOpStr(unittest.TestCase):
t = t + t * Tensor.rand(10)
# nice big complicated uop
with Context(NOOPT=1):
sink = UOp(UOps.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
sink = UOp(Ops.SINK, dtypes.void, (get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops[-1],))
self.assertEqual(sink, eval(str(sink)))
def test_vectorized_str(self):
vec = UOp(UOps.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
vec = UOp(Ops.VECTORIZE, dtypes.int.vec(4), tuple(UOp.const(dtypes.int, x) for x in range(4)))
assert str(eval(str(vec))) == str(vec)
@unittest.skip("uop no longer has order like this")
@@ -406,23 +406,23 @@ class TestIndexingOrdering(unittest.TestCase):
# NOTE: these tests skip type_verify since they add dtype to STORE
@unittest.expectedFailure
def test_simple_order(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = to_uops_list([st1, st0], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
@unittest.expectedFailure
def test_ordering_multi_output(self):
buf0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_1 = UOp(UOps.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
st0_0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_0 = UOp(Ops.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(Ops.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_1 = UOp(Ops.STORE, dtypes.float, (buf1, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = to_uops_list([st0_0, st1_0, st0_1, st1_1], skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
stores = [st for st in uops if st.op is Ops.STORE]
print("\n".join(map(str, stores)))
# buf0 stores come first
self.assertEqual(stores[0].src[0].arg, stores[1].src[0].arg)
@@ -433,12 +433,12 @@ class TestIndexingOrdering(unittest.TestCase):
assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
def test_simple_order_with_special(self):
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(Ops.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(Ops.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = linearize_uop(UOp.sink(st1, st0), skip_check=True)
stores = [st for st in uops if st.op is UOps.STORE]
stores = [st for st in uops if st.op is Ops.STORE]
assert stores[0].src[1] < stores[1].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
class TestUPatHelpers(unittest.TestCase):
@@ -447,7 +447,7 @@ class TestUPatHelpers(unittest.TestCase):
self.assertEqual(to_si.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].replace("\\", "/").split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)
test_upat = UPat(Ops.CONST, dtypes.bool)
self.assertEqual(test_upat.location[0].split("/")[-1], __file__.replace("\\", "/").split("/")[-1])
if __name__ == '__main__':

@@ -4,7 +4,7 @@ from tinygrad.helpers import getenv, GlobalCounters
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, Ops, UOp
from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -119,23 +119,23 @@ class TestUOpsStats(unittest.TestCase):
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
u5 = UOp(Ops.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = linearize_uop(u5.sink())
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(UOps.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
o2 = UOp(Ops.CONST, dtypes.int, tuple(), 2)
u1 = UOp(Ops.LOAD, dtypes.int, (globl.index(o1),))
u2 = UOp(Ops.LOAD, dtypes.int, (globl.index(o2),))
u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
u4 = UOp(Ops.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
uops_fma = linearize_uop(u4.sink())
self.assertEqual(flops_mem(uops), flops_mem(uops_fma))

@@ -1,7 +1,7 @@
from typing import Dict, List, Optional
import unittest
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, \
graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@@ -27,7 +27,7 @@ class TestViz(unittest.TestCase):
pm = PatternMatcher([
(UPat.var("x")*1, lambda x:x),
])
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a*1, pm)
self.assertEqual(len(uops), 1)
self.assertEqual(uops[0], a)
@@ -37,21 +37,21 @@ class TestViz(unittest.TestCase):
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)
self.assertEqual(len(uops), 2)
self.assertEqual(uops[0], a*2)
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
def test_rewrite_with_ctx(self):
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
def store_load(ctx:Dict[UOp, None], x:UOp) -> Optional[UOp]:
if x in ctx: return None
ctx[x] = None
return UOp.store(*x.src, x)
pm = PatternMatcher([
(UPat(UOps.LOAD, name="x"), store_load),
(UPat(Ops.LOAD, name="x"), store_load),
])
uops = helper_test_viz(a+b, pm, {})
self.assertEqual(len(uops), 2)
@@ -61,7 +61,7 @@ class TestViz(unittest.TestCase):
simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
@track_rewrites(named=True)
def do_rewrite(x:UOp): return graph_rewrite(x, simple)
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
do_rewrite(ld*1)
do_rewrite(ld*2)
ret = get_metadata(contexts)
@@ -79,13 +79,13 @@ class TestViz(unittest.TestCase):
def do_rewrite(x:UOp):
x = graph_rewrite(x, simple) # NOTE: viz tracks this
raise Exception("test")
ld = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
with self.assertRaises(Exception): do_rewrite(ld*1)
ret = get_metadata(contexts)
self.assertEqual(len(ret), 1)
def test_fold_const(self):
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
graph = uop_to_json(a)
assert not any(v[0].startswith("CONST") for v in graph.values())
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1

@@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor, GlobalCounters
from tinygrad.ops import UOps
from tinygrad.ops import Ops
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
@@ -23,7 +23,7 @@ class TestWinograd(unittest.TestCase):
sched = create_schedule([out.lazydata])
for i,s in enumerate(sched):
if s.ast.op is not UOps.SINK: continue
if s.ast.op is not Ops.SINK: continue
ops = s.ast.parents
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Kernel(s.ast)

@@ -1,7 +1,7 @@
import unittest, math
from tinygrad import dtypes
from tinygrad.helpers import all_same
from tinygrad.ops import UOp, UOps, BinaryOps, exec_alu
from tinygrad.ops import UOp, Ops, BinaryOps, exec_alu
from tinygrad.codegen.uopgraph import full_graph_rewrite
# Helper function to apply the graph rewrite
@@ -9,12 +9,12 @@ def apply_rewrite(expr):
return full_graph_rewrite(expr.sink()).src[0]
def evaluate_uop(uop, variables):
if uop.op == UOps.CONST:
if uop.op == Ops.CONST:
return uop.arg
elif uop.op == UOps.DEFINE_VAR:
elif uop.op == Ops.DEFINE_VAR:
var_name = uop.arg[0]
return variables[var_name]
elif uop.op == UOps.ALU:
elif uop.op == Ops.ALU:
src_values = [evaluate_uop(src, variables) for src in uop.src]
return exec_alu(uop.arg, uop.dtype, src_values)
else:
@@ -23,12 +23,12 @@ def evaluate_uop(uop, variables):
class TestArithmeticSimplifications(unittest.TestCase):
def test_full_graph_rewrite_division_by_zero(self):
optimized_div_uop = apply_rewrite(UOp.const(dtypes.float32, 10.0) / UOp.const(dtypes.float32, 0.0))
self.assertEqual(optimized_div_uop.op, UOps.CONST)
self.assertEqual(optimized_div_uop.op, Ops.CONST)
self.assertTrue(math.isinf(optimized_div_uop.arg) or math.isnan(optimized_div_uop.arg))
def test_full_graph_rewrite_redundant_operations(self):
optimized_uop = apply_rewrite((UOp.const(dtypes.float32, 10.0) + UOp.const(dtypes.float32, 0.0)) * UOp.const(dtypes.float32, 1.0))
self.assertEqual(optimized_uop.op, UOps.CONST)
self.assertEqual(optimized_uop.op, Ops.CONST)
self.assertEqual(optimized_uop.arg, 10.0)
def test_full_graph_rewrite_large_graph(self):
@@ -36,17 +36,17 @@ class TestArithmeticSimplifications(unittest.TestCase):
for i in range(1, 101):
prev_uop += UOp.const(dtypes.int32, i)
optimized_uop = apply_rewrite(prev_uop)
self.assertEqual(optimized_uop.op, UOps.CONST)
self.assertEqual(optimized_uop.op, Ops.CONST)
self.assertEqual(optimized_uop.arg, sum(range(1, 101)))
def test_full_graph_rewrite_division_by_one(self):
optimized_uop = apply_rewrite(UOp.const(dtypes.float32, 42.0) / UOp.const(dtypes.float32, 1.0))
self.assertEqual(optimized_uop.op, UOps.CONST)
self.assertEqual(optimized_uop.op, Ops.CONST)
self.assertEqual(optimized_uop.arg, 42.0)
def test_full_graph_rewrite_modulo_by_one(self):
optimized_uop = apply_rewrite(UOp.const(dtypes.int32, 42) % UOp.const(dtypes.int32, 1))
self.assertEqual(optimized_uop.op, UOps.CONST)
self.assertEqual(optimized_uop.op, Ops.CONST)
self.assertEqual(optimized_uop.arg, 0)
@@ -90,7 +90,7 @@ class TestFoldingAndReduction(unittest.TestCase):
inner_range = UOp.range(dtypes.int32, 0, 4, 1)
expr = (outer_range * 10) + inner_range
optimized_reduce_uop = apply_rewrite(expr.reduce(BinaryOps.ADD, outer_range, inner_range))
self.assertEqual(optimized_reduce_uop.op, UOps.CONST)
self.assertEqual(optimized_reduce_uop.op, Ops.CONST)
self.assertEqual(optimized_reduce_uop.arg, sum((i * 10) + j for i in range(8) for j in range(4)))
@@ -98,32 +98,32 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
def test_full_graph_rewrite_modulo_folding_with_define_var(self):
x_var_uop = UOp.variable('x', 0, 100)
optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
self.assertEqual(optimized_mod_uop.op, UOps.CONST)
self.assertEqual(optimized_mod_uop.op, Ops.CONST)
self.assertEqual(optimized_mod_uop.arg, 2)
def test_full_graph_rewrite_division_folding_with_define_var(self):
n_var_uop = UOp.variable('n', 1, 1000)
optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
self.assertEqual(optimized_div_uop.op, UOps.ALU)
self.assertEqual(optimized_div_uop.op, Ops.ALU)
self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
self.assertEqual(optimized_div_uop.src[1].arg, 2)
def test_full_graph_rewrite_complex_mod_div_folding(self):
k_var_uop = UOp.variable('k', 0, 50)
optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
self.assertEqual(optimized_div_uop.op, UOps.CONST)
self.assertEqual(optimized_div_uop.op, Ops.CONST)
self.assertEqual(optimized_div_uop.arg, 1)
def test_graph_rewrite_div_folding_bug(self):
lhs = UOp(UOps.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
UOp(UOps.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(UOps.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
UOp(UOps.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
lhs = UOp(Ops.ALU, dtypes.int.vec(4), arg=BinaryOps.ADD, src=(
UOp(Ops.VECTORIZE, dtypes.int.vec(4), arg=None, src=(UOp(Ops.SPECIAL, dtypes.int, arg=('lidx0', 32), src=()),)*4),
UOp(Ops.VCONST, dtypes.int.vec(4), arg=(0, 256, 512, 768), src=())))
rhs = UOp.const(dtypes.int.vec(4), 2)
unopt = lhs.lt(rhs)
opt = apply_rewrite(unopt)
print(unopt)
print(opt)
if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))
if opt.op is Ops.VECTORIZE: self.assertFalse(all_same(opt.src))
def test_full_graph_rewrite_modulo_large_divisor(self):
x_var_uop = UOp.variable('x', 1, 5)
@@ -180,12 +180,12 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
def test_gep_on_vconst(self):
# GEP on a VCONST to extract a single element
vconst = UOp(UOps.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(1.0, 2.0, 3.0, 4.0))
self.assertEqual(apply_rewrite(vconst.gep(2)).arg, 3.0)
def test_gep_tuple_on_vconst(self):
# GEP on a VCONST using a tuple to extract multiple elements
vconst = UOp(UOps.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
vconst = UOp(Ops.VCONST, dtypes.float32.vec(4), arg=(7.0, 8.0, 9.0, 10.0))
optimized_uop = apply_rewrite(vconst.gep((1, 3)))
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [8.0, 10.0])
@@ -198,7 +198,7 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
def test_vectorize_multiple_elements(self):
# Vectorizing multiple elements using GEP
base_vector = UOp.const(dtypes.float32.vec(4), (5.0, 10.0, 15.0, 20.0))
vectorized_uop = UOp(UOps.VECTORIZE, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
vectorized_uop = UOp(Ops.VECTORIZE, dtypes.float32.vec(4), src=(base_vector.gep(0), base_vector.gep(1), base_vector.gep(2), base_vector.gep(3)))
optimized_uop = apply_rewrite(vectorized_uop)
self.assertEqual([sub_uop.arg for sub_uop in optimized_uop.src], [5.0, 10.0, 15.0, 20.0])

@@ -1,13 +1,13 @@
import unittest, itertools
from tinygrad.dtype import dtypes
from tinygrad.ops import UOps, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.ops import Ops, UOp, BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.ops import PatternMatcher, UPat
class TestPatternMatcher(unittest.TestCase):
def test_simple_match(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.int, arg=1)
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.int, arg=1)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
@@ -30,9 +30,9 @@ class TestPatternMatcher(unittest.TestCase):
nonlocal match_cnt
match_cnt += 1
assert len(x.src) == 0
return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything
c1 = matcher.rewrite(c1)
c1 = matcher.rewrite(c1)
@@ -42,9 +42,9 @@ class TestPatternMatcher(unittest.TestCase):
def fxn(ctx, x):
ctx.append(True)
assert len(x.src) == 0
return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
return UOp(Ops.CONST, src=(UOp(Ops.CONST),))
matcher = PatternMatcher([(UPat(Ops.CONST, src=(), name="x"), fxn)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
# second rewrite shouldn't match anything
ctx = []
c1 = matcher.rewrite(c1, ctx)
@@ -52,33 +52,33 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(len(ctx), 1)
def test_uop(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.ALU, dtypes.float, (c1, c1), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_uop_set(self):
matcher = PatternMatcher([(UPat({UOps.CONST, UOps.CAST}, name="x"), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.bool, arg=False)
c2 = UOp(UOps.CAST, dtypes.int, (c1,))
c3 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c4 = UOp(UOps.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
matcher = PatternMatcher([(UPat({Ops.CONST, Ops.CAST}, name="x"), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.bool, arg=False)
c2 = UOp(Ops.CAST, dtypes.int, (c1,))
c3 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c4 = UOp(Ops.ALU, dtypes.float, (c3, c3), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
self.assertEqual(matcher.rewrite(c4), None)
def test_arg(self):
matcher = PatternMatcher([
(UPat(UOps.CONST, arg=0, name="x"), lambda x: x),
(UPat(UOps.CONST, arg=False, name="x"), lambda x: x),
(UPat(UOps.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=0, name="x"), lambda x: x),
(UPat(Ops.CONST, arg=False, name="x"), lambda x: x),
(UPat(Ops.ALU, arg=BinaryOps.MAX, name="x"), lambda x: x),
])
c1 = UOp(UOps.CONST, dtypes.float, arg=0.0)
c2 = UOp(UOps.CONST, dtypes.bool, arg=False)
c3 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
c4 = UOp(UOps.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
c5 = UOp(UOps.CONST, dtypes.int, arg=-1)
c1 = UOp(Ops.CONST, dtypes.float, arg=0.0)
c2 = UOp(Ops.CONST, dtypes.bool, arg=False)
c3 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MAX)
c4 = UOp(Ops.ALU, dtypes.float, (c1, c1), arg=BinaryOps.MUL)
c5 = UOp(Ops.CONST, dtypes.int, arg=-1)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
self.assertEqual(matcher.rewrite(c3), c3)
@@ -87,17 +87,17 @@ class TestPatternMatcher(unittest.TestCase):
def test_filter_arg(self):
matcher = PatternMatcher([
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=[UPat(UOps.CONST, name="c"), UPat(UOps.CONST, arg=2)], name="x"),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(Ops.CONST, name="c"), UPat(Ops.CONST, arg=2)], name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(UOps.CONST, dtypes.int, arg=1)
y2 = UOp(UOps.CONST, dtypes.int, arg=2)
y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
y1 = UOp(Ops.CONST, dtypes.int, arg=1)
y2 = UOp(Ops.CONST, dtypes.int, arg=2)
y3 = UOp(Ops.CONST, dtypes.int, arg=-1)
c1 = UOp(Ops.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
c2 = UOp(Ops.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
c3 = UOp(Ops.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
c4 = UOp(Ops.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
c5 = UOp(Ops.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
self.assertEqual(matcher.rewrite(c3), c3)
@@ -105,37 +105,37 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c5), c5)
def test_dup_name(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
y2 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c1 = UOp(UOps.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
c2 = UOp(UOps.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST, name="y"), UPat(Ops.CONST, name="y"))), lambda x, y: x)])
y1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
y2 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c1 = UOp(Ops.ALU, dtypes.float, (y1, y1), BinaryOps.ADD)
c2 = UOp(Ops.ALU, dtypes.float, (y1, y2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c1)
def test_dtype(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype=dtypes.float32), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_dtype_set(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
c4 = UOp(UOps.CONST, dtypes.int, arg=1)
matcher = PatternMatcher([(UPat(Ops.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float64, arg=1.0)
c3 = UOp(Ops.CONST, dtypes.float16, arg=1.0)
c4 = UOp(Ops.CONST, dtypes.int, arg=1)
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), c2)
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), None)
def test_src_one(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST), UPat(Ops.CONST))), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c2), None)
# that CONST/ALU -> ALU/CONST rewrite is now instant
@@ -149,46 +149,46 @@ class TestPatternMatcher(unittest.TestCase):
"""
def test_src_permutations(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c4 = UOp(UOps.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
c5 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
c6 = UOp(UOps.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=[UPat(Ops.CONST), UPat(Ops.ALU)]), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c4 = UOp(Ops.ALU, dtypes.float, (c3,c2), BinaryOps.ADD)
c5 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
c6 = UOp(Ops.ALU, dtypes.float, (c3,c4), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), c5)
self.assertEqual(matcher.rewrite(c6), None)
def test_src_repeat(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c4 = UOp(UOps.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=UPat(Ops.CONST)), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c4 = UOp(Ops.ALU, dtypes.float, (c2,c3), BinaryOps.ADD)
self.assertEqual(matcher.rewrite(c3), c3)
self.assertEqual(matcher.rewrite(c4), None)
def test_allow_len(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c3 = UOp(UOps.CONST, dtypes.float, arg=3.0)
c4 = UOp(UOps.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
c5 = UOp(UOps.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c6 = UOp(UOps.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
matcher = PatternMatcher([(UPat(Ops.ALU, name="x", src=(UPat(Ops.CONST),), allow_any_len=True, arg=TernaryOps.MULACC), lambda x: x)])
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
c3 = UOp(Ops.CONST, dtypes.float, arg=3.0)
c4 = UOp(Ops.ALU, dtypes.float, (c1,), UnaryOps.EXP2)
c5 = UOp(Ops.ALU, dtypes.float, (c1,c2), BinaryOps.ADD)
c6 = UOp(Ops.ALU, dtypes.float, (c1,c2,c3), TernaryOps.MULACC)
self.assertEqual(matcher.rewrite(c4), None)
self.assertEqual(matcher.rewrite(c5), None)
self.assertEqual(matcher.rewrite(c6), c6)
def test_deep_src_permutations(self):
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
c1 = UOp(Ops.CONST, dtypes.float, arg=1.0)
c2 = UOp(Ops.CONST, dtypes.float, arg=2.0)
u1 = (c1 + c2) + c1
u2 = (c2 + c1) + c1
matcher = PatternMatcher([
(UPat(UOps.ALU, src=[UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
(UPat(Ops.ALU, src=[UPat(Ops.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
])
self.assertIsNotNone(matcher.rewrite(u1))
self.assertIsNotNone(matcher.rewrite(u2))
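(For reference: a minimal sketch of the renamed matcher API, mirroring the tests above. The x*1 rule is only an illustration, not a rule added by this commit.)

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, BinaryOps

# fold x*1 -> x; src is given as a list, so both source orders match (as in test_src_permutations)
m = PatternMatcher([
  (UPat(Ops.ALU, arg=BinaryOps.MUL, src=[UPat(name="x"), UPat(Ops.CONST, arg=1)]), lambda x: x),
])
one = UOp(Ops.CONST, dtypes.int, arg=1)
y = UOp(Ops.CONST, dtypes.int, arg=7)
assert m.rewrite(UOp(Ops.ALU, dtypes.int, (y, one), BinaryOps.MUL)) is y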
View File
@@ -5,14 +5,14 @@ from tinygrad.dtype import dtypes
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad import Variable
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad.ops import UOp, Ops, graph_rewrite
from tinygrad.codegen.uopgraph import sym
from itertools import product
def shapetracker_getitem(st:ShapeTracker, val:int):
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
assert idx.op is UOps.CONST and valid.op is UOps.CONST
assert idx.op is Ops.CONST and valid.op is Ops.CONST
return idx.arg, valid.arg
class CheckingShapeTracker:
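(A sketch of what shapetracker_getitem above relies on: once the input index is a constant, graph_rewrite with sym folds the whole indexing expression down to CONSTs. The permuted tracker is just an example view.)

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, graph_rewrite
from tinygrad.codegen.uopgraph import sym
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4)).permute((1, 0))
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, 5)])
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
assert idx.op is Ops.CONST and valid.op is Ops.CONST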
View File
@@ -3,27 +3,27 @@ from typing import Tuple
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, simplify_valid
from tinygrad.ops import UOp, Ops, simplify_valid
def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(UOps.LOAD, dtypes.float, (
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
return UOp(Ops.LOAD, dtypes.float, (
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
idx,
UOp.const(dtypes.float, 0.0),
valid
))
def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:Tuple[UOp, UOp]):
return UOp(UOps.LOAD, dtypes.float.vec(4), (
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx),
UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
return UOp(Ops.LOAD, dtypes.float.vec(4), (
UOp(Ops.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
UOp(Ops.VECTORIZE, dtypes.int.vec(2), idx),
UOp(Ops.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
valid
))
def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
def Special(expr, nmax): return UOp(Ops.SPECIAL, dtypes.int, (), (expr, nmax))
def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
def Range(n, nmax): return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
def Range(n, nmax): return UOp(Ops.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
@@ -43,7 +43,7 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(is_increasing(f2))
self.assertTrue(is_increasing(f3))
rng = UOp(UOps.RANGE, dtypes.int, arg=(2, True), src=(UOp(UOps.CONST, dtypes.int, arg=0, src=()), UOp(UOps.CONST, dtypes.int, arg=5, src=()),))
rng = UOp(Ops.RANGE, dtypes.int, arg=(2, True), src=(UOp(Ops.CONST, dtypes.int, arg=0, src=()), UOp(Ops.CONST, dtypes.int, arg=5, src=()),))
self.assertTrue(is_increasing(rng))
self.assertTrue(is_increasing(rng+2))
@@ -87,7 +87,7 @@ class TestImageSimplification(unittest.TestCase):
def check(self, load, svalid, sidx0, sidx1):
load = full_graph_rewrite(load.sink()).src[0]
idx = load.src[0].src[1]
self.assertEqual(idx.op, UOps.VECTORIZE)
self.assertEqual(idx.op, Ops.VECTORIZE)
self.assertEqual(len(idx.src), 2)
idx0, idx1 = idx.src[0], idx.src[1]
self.assertEqual(idx0.render(simplify=False), sidx0)
@@ -152,7 +152,7 @@ class TestImageSimplification(unittest.TestCase):
# empty -> invalid
load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
load = full_graph_rewrite(load.sink()).src[0]
self.assertEqual(load.op, UOps.VECTORIZE)
self.assertEqual(load.op, Ops.VECTORIZE)
self.assertEqual(load.dtype.count, 4)
def test_openpilot_conv1(self):

View File
@@ -9,15 +9,15 @@ from typing import Tuple
from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from tinygrad.ops import UOp, UOps, graph_rewrite, sym_infer
from tinygrad.ops import UOp, Ops, graph_rewrite, sym_infer
from tinygrad import Variable
import functools
def render(self) -> Tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
uops = linearize_uop(full_graph_rewrite(UOp(Ops.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[-1]
return rewritten_uop.render(simplify=False), rewritten_uop.vmin, rewritten_uop.vmax
def NumNode(val): return UOp.const(dtypes.int, val)

View File
@@ -1,5 +1,5 @@
import unittest, math
from tinygrad.ops import UOp, UOps
from tinygrad.ops import UOp, Ops
from tinygrad.dtype import dtypes
class TestVminVmaxProperties(unittest.TestCase):
@@ -35,7 +35,7 @@ class TestVminVmaxProperties(unittest.TestCase):
def test_vmin_vmax_multiplication_0_inf(self):
# vmin and vmax for multiplication with a variable
x = UOp.const(dtypes.float, 0.0)
y = UOp.load(UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
y = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), UOp.const(dtypes.int, 0), dtype=dtypes.float)
uop = x * y
# TODO: these should be 0, but definitely should not be nan
self.assertEqual(uop.vmin, -math.inf)
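(A short sketch of the bound tracking these tests exercise, assuming UOp.variable as used in the symbolic test helpers above; vmin/vmax propagate through ALU ops on the renamed Ops nodes.)

from tinygrad.ops import UOp

x = UOp.variable("x", 0, 10)
assert (x * 2).vmin == 0 and (x * 2).vmax == 20
assert (x + 5).vmin == 5 and (x + 5).vmax == 15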
View File
@@ -4,7 +4,7 @@ import unittest
from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel
from tinygrad.helpers import DEBUG
from tinygrad.ops import UOp, UOps, ReduceOps, print_uops
from tinygrad.ops import UOp, Ops, ReduceOps, print_uops
from tinygrad.codegen.kernel import verify_ast
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad import dtypes
@@ -12,7 +12,7 @@ from tinygrad.shape.view import View
class InvalidASTException(Exception): pass
def helper_test_verify_ast(*stores:UOp) -> Kernel:
sink = UOp(UOps.SINK, dtypes.void, stores)
sink = UOp(Ops.SINK, dtypes.void, stores)
if DEBUG >= 3:
for op in stores: print(op)
try: verify_ast(sink)
@@ -26,59 +26,59 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
class TestVerifyAST(unittest.TestCase):
def test_tiny_add(self):
dtype = dtypes.int
buf_0 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 0)
buf_1 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 2)
a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
buf_0 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 0)
buf_1 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 1)
buf_2 = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), 2)
a = UOp(Ops.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
b = UOp(Ops.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
store = UOp(Ops.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
helper_test_verify_ast(store)
def test_exactly_one_full_shape(self):
dtype = dtypes.int
bufs = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
a = UOp(Ops.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
b = UOp(Ops.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
a = UOp(UOps.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
b = UOp(UOps.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
a = UOp(Ops.LOAD, dtype, (bufs[4], ShapeTracker.from_shape((32, 32)).to_uop()))
b = UOp(Ops.LOAD, dtype, (bufs[5], ShapeTracker.from_shape((32, 32)).to_uop()))
st1 = UOp.store(bufs[1], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
def test_no_implicit_broadcasting(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
b = a + UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
st = UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_shrink_ok(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
b = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
helper_test_verify_ast(st)
def test_reduce_store(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
def test_reduce_add_store(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(Ops.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
def test_buffer_uops_st(self):
a = Tensor.randn(4, 4)+2
uop_sts = verify_ast(a.schedule()[-1].ast)
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
store_st = [st for u,st in uop_sts.items() if u.op is Ops.STORE][0]
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
const_st = [st for u,st in uop_sts.items() if u.op is Ops.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
if __name__ == '__main__':
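(The same fact test_buffer_uops_st leans on, as a one-line sketch outside the test harness: every kernel AST handed to verify_ast is a SINK of STOREs.)

from tinygrad import Tensor
from tinygrad.ops import Ops

ast = (Tensor.randn(4, 4) + 2).schedule()[-1].ast
assert ast.op is Ops.SINK and all(s.op is Ops.STORE for s in ast.src)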
View File
@@ -5,7 +5,7 @@ from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence
from enum import Enum, auto
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, print_uops, type_verify, resolve, \
graph_rewrite, track_rewrites, Variable, sint
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
@@ -54,7 +54,7 @@ class TensorCoreOptions:
class Kernel:
def __init__(self, ast:UOp, opts:Optional[Renderer]=None):
if ast.op is UOps.SINK: self.ast = ast
if ast.op is Ops.SINK: self.ast = ast
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
try: uop_sts_map = verify_ast(self.ast)
@@ -65,7 +65,7 @@ class Kernel:
@functools.lru_cache(None)
def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is UOps.REDUCE_AXIS])
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
self.vars: List[Variable] = self.ast.variables()
self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
@@ -125,7 +125,7 @@ class Kernel:
return ret
@property
def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {UOps.LOAD, UOps.STORE}])
def membufs(self) -> List[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
# TODO: these need more tests or it might silently be no-op
def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
@@ -273,16 +273,16 @@ class Kernel:
def _create_tc_opts(self, reduceop:UOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
has_cast = tc.dtype_in != tc.dtype_out
if has_cast and not (reduceop.src[0].op is UOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
if mul_op.arg is not BinaryOps.MUL: return None
def buf_index(src:UOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
if src.op is UOps.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
if src.op is Ops.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(src)
try:
if opt_level >= 1 and src.op is UOps.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
if opt_level >= 1 and src.op is Ops.CAST and src.dtype == tc.dtype_in: return self.bufs.index(src.src[0])
except ValueError: return None
return None
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
@@ -442,7 +442,7 @@ class Kernel:
check(axis < self.first_upcast, "cannot pad upcasted")
# ok to pad SUM if all parent ALU ops have f(0) = 0
if (r:=self.reduceop) is not None and self.first_reduce <= axis:
check(r.arg[0] is BinaryOps.ADD and not any(u.op is UOps.ALU and u.arg in UNSAFE_PAD_OPS for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
check(r.arg[0] is BinaryOps.ADD and not any(u.op is Ops.ALU and u.arg in UNSAFE_PAD_OPS for u in r.parents), "cannot pad UNSAFE_PAD_OPS")
padded = False
for i,st in enumerate(self.sts):
if (s:=st.shape[axis]) == 1: continue # reduced
@@ -472,7 +472,7 @@ class Kernel:
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
(mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is UOps.LOAD and mulop.src[1].op is UOps.LOAD:
(mulop:=self.reduceop.src[0]).arg is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
strides0, strides1 = st0.real_strides(), st1.real_strides()
def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
@@ -615,20 +615,20 @@ class Kernel:
arg = op.arg
if op.op in BUFFER_UOPS:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st = op.st_arg if op.src[0].op is Ops.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
if op.op is UOps.VALID: return op.replace(src=(st_uop,))
if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
if op.op is Ops.VALID: return op.replace(src=(st_uop,))
if op.op is Ops.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
if op.op is UOps.REDUCE_AXIS:
if op.op is Ops.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
alu_op: BinaryOps = op.arg[0]
axis = tuple(i for i in range(self.first_reduce+self.group_for_reduces, self.shape_len)
if resolve(self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i]))
if op in self.bufs_for_tensor_core and (tc := self.tensor_core):
rsrc = op.src[0]
if rsrc.op is UOps.CAST: rsrc = rsrc.src[0]
assert rsrc.op is UOps.ALU and rsrc.arg is BinaryOps.MUL
if rsrc.op is Ops.CAST: rsrc = rsrc.src[0]
assert rsrc.op is Ops.ALU and rsrc.arg is BinaryOps.MUL
def fix_st(warp_dims, tcd_dims, tcd_expand, pattern_1, pattern_2, st1):
wd, tcd = self.global_dims, self.first_upcast
@@ -654,30 +654,30 @@ class Kernel:
for i,s in enumerate(self.full_shape))
srcs = []
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is Ops.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
membuf = UOp(UOps.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
membuf = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(Ops.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(Ops.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
else:
# for TC=2, we can't do the shapetracker fixup
srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
# MUL/SUM instead of WMMA
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
else:
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
wmma_upcast_axes = wmma_arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
UOp(Ops.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
UOp(Ops.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
ret = UOp(Ops.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in [ax for ax, _ in tc.reduce_axes])
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
if self.group_for_reduces:
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
start = UOp(Ops.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(alu_op, axis))
second_axis = tuple(i for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces) \
if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i])
# NOTE: if there's a grouped reduce, but no reduce axes for this reduce, we can skip it
@@ -687,14 +687,14 @@ class Kernel:
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(UOps.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
if op is self.reduceops[-1]: return grouped_reduce
st_uop = ShapeTracker.from_shape(tuple([1 if i in second_axis else a for i,a in enumerate(local_shape)])).to_uop()
return UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
return UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, grouped_reduce)))
arg = (alu_op, axis)
elif op.op is UOps.SINK:
elif op.op is Ops.SINK:
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
# NOTE: rewrite with an empty PatternMatcher to dedup UOps
@@ -728,7 +728,7 @@ class Kernel:
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.src[0].dtype.itemsize * x.st_arg.real_size() for x in group)
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is UOps.DEFINE_GLOBAL],
for _, group in itertools.groupby([x for x in self.ast.parents if x.op in BUFFER_UOPS and x.src[0].op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].arg)))
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
@@ -738,15 +738,15 @@ class Kernel:
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if not uop.has_st or uop in sts: return
# restore globals from the two stage reduce
if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
sts[uop] = sts[local_reduce]
return
for x in uop.src: _assert_valid_uop(x, st, sts)
# only reduceuop is allowed to change shape, limited to turning n to 1
if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
if uop.op in {Ops.REDUCE_AXIS, Ops.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.axis_arg))
# movementops are pushed to VIEW
elif uop.op is UOps.VIEW: st = uop.arg
elif uop.op is Ops.VIEW: st = uop.arg
# everything else inherits shape
else:
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
@@ -756,7 +756,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
sts[uop] = st
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
assert ast.op is Ops.SINK and all(x.op is Ops.STORE for x in ast.src), "must be SINK"
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
sts: Dict[UOp, ShapeTracker] = {}
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
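(A sketch of the verify_ast contract checked here, mirroring test_tiny_add above: a SINK of STOREs over same-shaped views goes in, a ShapeTracker per uop comes out.)

from tinygrad.dtype import dtypes
from tinygrad.ops import Ops, UOp
from tinygrad.codegen.kernel import verify_ast
from tinygrad.shape.shapetracker import ShapeTracker

bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
st = ShapeTracker.from_shape((16, 1)).to_uop()
a, b = [UOp(Ops.LOAD, dtypes.int, (buf, st)) for buf in bufs[1:]]
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], st, a + b)),))
sts = verify_ast(sink)  # Dict[UOp, ShapeTracker]; raises if the shapes don't line up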
View File
@@ -1,6 +1,6 @@
from typing import List, Set, Dict, Tuple
import functools, heapq
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, UOps
from tinygrad.ops import type_verify, END_FOR_UOP, UOp, Ops
from tinygrad.dtype import dtypes
from tinygrad.helpers import DEBUG
@@ -10,13 +10,13 @@ def get_children_dfs(u:UOp, children:Dict[UOp, List[UOp]], srcs:Dict[UOp, Dict[U
children[u] = []
for x in u.src:
srcs[u].update(get_children_dfs(x, children, srcs, in_degree))
if x.op is UOps.RANGE and x.arg[1]: srcs[u][x] = None
if x.op is Ops.RANGE and x.arg[1]: srcs[u][x] = None
children[x].append(u)
in_degree[u] = len(u.src)
return srcs[u]
def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
# filter nodes that don't link to a sink
# BFS toposort
children: Dict[UOp, List[UOp]] = {}
@@ -25,38 +25,38 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
get_children_dfs(sink, children, range_srcs, in_degree)
@functools.lru_cache(None)
def get_recursive_children(x:UOp, end:UOps, include_self=False) -> Set[UOp]:
if x.op is UOps.SINK: return set()
def get_recursive_children(x:UOp, end:Ops, include_self=False) -> Set[UOp]:
if x.op is Ops.SINK: return set()
return set.union({x} if include_self else set(), *([get_recursive_children(u, end, True) for u in children[x] if x.op is not end]))
# scope children impact the toposort and END* insertion
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE}
range_phi = {r:[p for p in scope_children[r] if p.op is Ops.ASSIGN] for r in scope_children if r.op is Ops.RANGE}
# assign priorities
def get_priority(u:UOp):
priority = 0
# prefer ranges that depend on the least number of independent ranges
if u.op is UOps.RANGE and u.arg[1]:
if u.op is Ops.RANGE and u.arg[1]:
priority += u.arg[0]
for p in range_phi[u]:
priority += 10000*len([r for r in range_srcs[p] if not any(i in range_phi[u] for i in range_phi[r])])
elif u.op is UOps.CONST:
elif u.op is Ops.CONST:
# place consts first here, they don't do anything and it can cause issues with DEFINE_ACC
priority -= 100000000000
else:
# prefer uops that are loop children
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
if u.op is UOps.IF and len(u.src) == 1: priority += 10000000 # if penalty
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is Ops.RANGE and u in ss])
if u.op is Ops.IF and len(u.src) == 1: priority += 10000000 # if penalty
return priority
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
# prevent priority inversion
@functools.lru_cache(None)
def fix_priority(u:UOp, lowest_priority):
if u.op in {UOps.CAST, UOps.BITCAST, UOps.ALU, UOps.VECTORIZE, UOps.GEP, UOps.SPECIAL, UOps.DEFINE_LOCAL, UOps.LOAD}:
if u.op in {Ops.CAST, Ops.BITCAST, Ops.ALU, Ops.VECTORIZE, Ops.GEP, Ops.SPECIAL, Ops.DEFINE_LOCAL, Ops.LOAD}:
priorities[u] = min(priorities[u], lowest_priority)
if u.op is UOps.LOAD: priorities[u] += 100 # load penalty (here)
if u.op is Ops.LOAD: priorities[u] += 100 # load penalty (here)
for x in u.src: fix_priority(x, priorities[u])
fix_priority(sink, 0)
@@ -73,8 +73,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
p,_,x = heapq.heappop(queue)
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
if x in scope_children: scope_end[x] = x
if x.op is UOps.DEFINE_ACC:
idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
if x.op is Ops.DEFINE_ACC:
idx = min([_uops.index(l) for l in x.src if l.op is Ops.RANGE])
_uops.insert(idx, x)
else: _uops.append(x)
for u, ss in scope_children.items():
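(A sketch of linearize_uop in isolation, using the same recipe as render() in the symbolic tests above: fully rewrite a SINK, then linearize it into an ordered uop list.)

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, print_uops
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite

glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
store = UOp(Ops.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 42)))
uops = linearize_uop(full_graph_rewrite(store.sink()))
print_uops(uops)  # ends with the STORE; earlier entries are its dependencies in order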
View File
@@ -6,7 +6,7 @@ from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import variable_to_uop
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.ops import KernelInfo, BinaryOps, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten
@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
if reverse: dims = dims[::-1]
limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
ret = raw_idxs = [UOp(Ops.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
if limited != dims:
ret = []
# cast for mypy, get_contraction won't be None
@@ -57,8 +57,8 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
first_upcasted = len(full_shape)-ki.upcasted
first_output_st: ShapeTracker = ast.src[0].st_arg
# if there's no reduce, this is first_upcasted. assumes reduces are at the end
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is UOps.REDUCE_AXIS))
local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
first_reduce = min([first_upcasted]+flatten(x.axis_arg for x in ast.sparents if x.op is Ops.REDUCE_AXIS))
local_loads = [x for x in ast.parents if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL]
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
idxs = [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(full_shape[:first_reduce])]
# reduce loops
idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
idxs += [UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
idxs.append(UOp(Ops.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
ridxs[a] = UOp(Ops.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
return IndexContext(idxs, ridxs)
@@ -98,50 +98,50 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
def lower_reduce_axis(ctx: IndexContext, x: UOp):
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is UOps.RANGE)
assert all(x.op is UOps.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.axis_arg], lambda y: y.op is Ops.RANGE)
assert all(x.op is Ops.EXPAND for x in reduce_expand), f"not all EXPANDS in {reduce_expand} for {x.axis_arg}"
alu_op: BinaryOps = x.arg[0]
ret = x.src[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = UOp(Ops.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
return UOp(Ops.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
def lower_load_store(ctx: IndexContext, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is Ops.LOAD and x.src[0].op is Ops.DEFINE_LOCAL else ctx.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
has_valid = valid.op is not Ops.CONST or valid.arg is not True
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
if x.op is Ops.LOAD:
barrier = (UOp(Ops.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is Ops.DEFINE_LOCAL else ()
return UOp(Ops.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
store_back = x.src[0].op is Ops.DEFINE_LOCAL and x.src[2].op is Ops.REDUCE and \
x.src[2].src[0].op is Ops.LOAD and x.src[2].src[0].src[0].op is Ops.DEFINE_LOCAL
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
if x.src[0].op is Ops.DEFINE_GLOBAL or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx is not ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not UOps.CONST or valid.arg is not True
return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
has_valid = valid.op is not Ops.CONST or valid.arg is not True
return UOp(Ops.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
pm_lowerer = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(UOps.VALID, src=(UPat(UOps.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
(UPat(Ops.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(Ops.VALID, src=(UPat(Ops.VIEW),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.VIEW)), allow_any_len=True, name="x"), lower_load_store),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
])
def do_reduce(ctx:List[int], root:UOp):
acc = UOp(UOps.DEFINE_ACC, root.dtype,
acc = UOp(Ops.DEFINE_ACC, root.dtype,
(root.const_like(identity_element(root.arg, root.dtype.scalar())),) + tuple(root.src[1:]), (ctx[0],))
ctx[0] += 1
return acc.assign(acc.alu(root.arg, root.src[0]))
just_reduce = PatternMatcher([
# do reduce
(UPat(UOps.REDUCE, name="root"), do_reduce),
(UPat(Ops.REDUCE, name="root"), do_reduce),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:

View File
@@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict,
import functools, itertools, operator
from collections import defaultdict
from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, UOps, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple
from tinygrad.ops import graph_rewrite, is_irreducible, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
@@ -16,15 +16,15 @@ def fold_expanded(ex, buf):
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType): return None
new_srcs = dedup(list(ex.src))
old_new_srcs = new_srcs[:]
is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
is_load, is_image = new_srcs[0].op is Ops.LOAD, isinstance(buf.dtype, ImageDType)
# first, extract all the relevant offsets
offsets_rootsrc: DefaultDict[Any, dict] = defaultdict(dict)
for i,s in enumerate(new_srcs):
idx = s.src[0].src[1]
if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue
if idx.arg is BinaryOps.ADD and idx.src[1].op is UOps.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is UOps.CONST: root_src, arg = "CONST", idx.arg
if idx.arg is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
# add gates for gated
if len(s.src[0].src) == 3: root_src = (s.src[0].src[2], root_src)
@@ -45,20 +45,20 @@ def fold_expanded(ex, buf):
if is_image:
# for images, we rewrite the index. it must evenly divide 4 from the above check
new_src[0] = buf.index(
UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
rootsrc[0] if isinstance(rootsrc, tuple) else None)
else:
# for non image, we upcast the index pointer
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
# vectorize the store
if not is_load:
new_src[1] = UOp(UOps.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
new_src[1] = UOp(Ops.VECTORIZE, new_src[1].dtype.vec(fold_length), tuple(new_srcs[offsets[o+i]].src[1] for i in range(fold_length)))
# generate the folded new_srcs
if is_load:
new_load = UOp(UOps.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
for i in range(fold_length): new_srcs[offsets[o+i]] = new_load.gep(i)
else:
for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(UOps.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
for i in range(fold_length): new_srcs[offsets[o+i]] = UOp(Ops.STORE, dtypes.void, tuple(new_src)) if i == 0 else None
for i in range(fold_length): used.add((rootsrc,o+i))
# dedup expand for LOAD
@@ -72,15 +72,15 @@ def fix_unfoldable_image_load(load:UOp, buf:UOp):
new_src = list(load.src)
# TODO: copied logic from above
new_src[0] = load.src[0].src[0].index(
UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
UOp(Ops.VECTORIZE, dtypes.int.vec(2), ((oidx // 4) % buf.dtype.shape[1], (oidx // (4*buf.dtype.shape[1])))),
load.src[0].src[2] if len(load.src[0].src) == 3 else None)
vec_load = UOp(UOps.LOAD, load.dtype.vec(4), tuple(new_src))
vec_load = UOp(Ops.LOAD, load.dtype.vec(4), tuple(new_src))
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, vec_load.gep(i)), range(4), load.const_like(float('nan')))
buf_idx_pat = UPat(UOps.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
buf_idx_pat = UPat(Ops.INDEX, src=(UPat.var("buf"),), allow_any_len=True)
float4_folding = PatternMatcher([
(UPat(UOps.VECTORIZE, src=UPat(UOps.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat((UOps.BARRIER, UOps.SINK), src=UPat(UOps.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat(Ops.VECTORIZE, src=UPat(Ops.LOAD, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
(UPat((Ops.BARRIER, Ops.SINK), src=UPat(Ops.STORE, src=(buf_idx_pat,), allow_any_len=True), name="ex"), fold_expanded),
])
# ***** image load valid simplification *****
@@ -124,18 +124,18 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]:
powers_of_two = {2**i:i for i in range(64)}
@functools.lru_cache(None)
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(UOps.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
pat: List[Tuple[UPat, Callable]] = [(UPat(Ops.ALU, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),), arg=op), f) for op,f in \
((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
if BinaryOps.AND in ops:
pat += [(UPat(UOps.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
pat += [(UPat(Ops.ALU, arg=BinaryOps.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
if BinaryOps.SHL in ops and BinaryOps.SHR in ops:
pat += [
(UPat(UOps.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
(UPat(Ops.ALU, arg=BinaryOps.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(UOps.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
(UPat(Ops.ALU, arg=BinaryOps.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
if UnaryOps.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))]
@@ -175,8 +175,8 @@ def loop_collapse(compval, multconst, rng:UOp, acc:UOp, idx2=None,idx3=None,extr
if vec is not None:
# add, mul, loop_start, loop_end
def dvec(x:UOp):
if x.op is UOps.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
return UOp(UOps.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
if x.op is Ops.CONST: return UOp.const(x.dtype.vec(vec.dtype.count), x.arg)
return UOp(Ops.VECTORIZE, x.dtype.vec(vec.dtype.count), src=(x,)*vec.dtype.count)
add, mul, loop_start, loop_end = dvec(add), dvec(mul), dvec(loop_start), dvec(loop_end)
if mul.vmin > 0 and ne is not None:
comprange = UOp.minimum(loop_end, UOp.maximum((add-compval)//mul + (loop_end-loop_start), loop_start))
@@ -209,7 +209,7 @@ def gep_through_wmma(gep:UOp, wmma:UOp):
ssz = prod(x[1] for x in sz)
for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
tsrcs.append(s.gep(tuple(src_args)))
return UOp(UOps.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
return UOp(Ops.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
@@ -218,9 +218,9 @@ def no_vectorized_wmma(wmma:UOp):
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmmas = [UOp(Ops.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
return UOp(Ops.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
reduce_parented, reduce_unparented = partition(acc.src[1:], lambda x: x in ret.sparents)
@@ -231,81 +231,81 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp):
for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
return ret
acc_pat, rng_pat = UPat(UOps.DEFINE_ACC, name="acc"), UPat(UOps.RANGE, name="rng")
acc_pat, rng_pat = UPat(Ops.DEFINE_ACC, name="acc"), UPat(Ops.RANGE, name="rng")
rng_aug = UPat.any(rng_pat, UPat.var("add")+rng_pat, UPat.var("mul")*rng_pat, UPat.var("add")+UPat.var("mul")*rng_pat)
index_load = UPat.var("buf").index(rng_aug).load(name="ld")
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(UOps.VECTORIZE, name="vec", src=rng_aug))
arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(UOps.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
arange_augrng = UPat.any(rng_aug, rng_aug+UPat.var("idx2"), rng_aug+UPat.var("idx2")+UPat.var("idx3"), UPat(Ops.VECTORIZE, name="vec", src=rng_aug))
arange_m = arange_augrng.lt(UPat.cvar("compval")).ne(UPat(Ops.CONST, name="ne", arg=True)).where(UPat.cvar("multconst"), UPat.const(None, 0))
# this is symbolic 2.0
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(UOps.VECTORIZE, src=UPat(UOps.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
# reorder ALU/VECTORIZE
(UPat(UOps.ALU, src=(UPat(UOps.VECTORIZE, src=UPat(name='x')), UPat(UOps.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
(UPat(Ops.ALU, src=(UPat(Ops.VECTORIZE, src=UPat(name='x')), UPat(Ops.VECTORIZE, src=UPat(name='y'))), name='alu'),
lambda x,y,alu: UOp(Ops.VECTORIZE, alu.dtype, (UOp(Ops.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
# VECTORIZE of a single element is just that element
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(UPat(UOps.VECTORIZE, dtype=dtypes.void, src=UPat(UOps.BARRIER, name='b')), lambda b: b),
(UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, src=UPat(Ops.BARRIER, name='b')), lambda b: b),
(UPat(Ops.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(Ops.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
(UPat(Ops.GEP, src=(UPat(Ops.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
(UPat(UOps.GEP, src=(UPat(UOps.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(UOps.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(UOps.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(UOps.GEP, src=(UPat(UOps.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
(UPat(Ops.GEP, src=(UPat(Ops.VECTORIZE, name="vec"),), name="gep"),
lambda gep, vec: UOp(Ops.VECTORIZE, gep.dtype, tuple(vec.src[i] for i in gep.arg)) if len(gep.arg) > 1 else vec.src[gep.arg[0]]),
(UPat(Ops.GEP, src=(UPat.cvar("c", vec=False),), name="gep"), lambda gep, c: gep.const_like(c.arg)),
(UPat(Ops.GEP, src=(UPat(Ops.VCONST, name="c"),), name="gep"), lambda gep, c: gep.const_like(tuple(c.arg[x] for x in gep.arg))),
# push all GEPs through ALUs (fix arange stuff)
(UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
(UPat(Ops.GEP, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# push some GEPs through WMMAs
(UPat(UOps.GEP, src=(UPat(UOps.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
(UPat(Ops.GEP, src=(UPat(Ops.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# tensor core with a 0 input is acc
(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(Ops.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
# tensor core cleanups
(UPat.var("add") + UPat(UOps.WMMA, name="wmma"),
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
(UPat(UOps.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
(UPat(Ops.ALU, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key")), arg=BinaryOps.THREEFRY), threefry2x32),
# arange loop folding
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(UOps.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# parentless reduce
(acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(UOps.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.ADD, name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.ALU, src=[acc_pat, UPat.var("ret")], arg=BinaryOps.MAX, name="alu")), reduce_collapse),
# ** self folding **
(UPat(UOps.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(UOps.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
# x!=0 -> (bool)x
(UPat.var("x").ne(0), lambda x: x.cast(dtypes.bool.vec(x.dtype.count))),
# ** load/store folding **
(UPat.store(UPat(UOps.INDEX, name="index"), UPat.load(UPat(UOps.INDEX, name="index"))), lambda index: UOp(UOps.NOOP)),
(UPat.store(UPat(UOps.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(UOps.INDEX, name="index")))),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.load(UPat(Ops.INDEX, name="index"))), lambda index: UOp(Ops.NOOP)),
(UPat.store(UPat(Ops.INDEX, name="index"), UPat.var("gate").where(UPat.var("alt"), UPat.load(UPat(Ops.INDEX, name="index")))),
lambda index, gate, alt: UOp.store(index.src[0].index(index.src[1], gate), alt)),
# fold gated LOAD/STORE
(UPat().index(UPat(), UPat.const(dtypes.bool, True)).named("idx"), lambda idx: idx.replace(src=idx.src[0:2])), # remove True
(UPat().index(UPat(), UPat.const(dtypes.bool, False)).named("idx"), lambda idx: idx.const_like(0)), # False -> NULL pointer
(UPat(UOps.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(UPat(UOps.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(UOps.NOOP)), # NULL pointer store does nothing
(UPat(Ops.LOAD, src=(UPat.const(None, 0),), allow_any_len=True, name="x"), lambda x: x.const_like(0)), # NULL pointer load loads 0
(UPat(Ops.STORE, src=(UPat.const(None, 0),), allow_any_len=True), lambda: UOp(Ops.NOOP)), # NULL pointer store does nothing
# remove NOOPs from SINK
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not UOps.NOOP)) != len(root.src) else None),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
# remove EXPANDs from SINK/BARRIER
(UPat(UOps.BARRIER, src=(UPat((UOps.VECTORIZE, UOps.SINK), name='sink'),)), lambda sink: UOp(UOps.BARRIER, dtypes.void, sink.src)),
(UPat(UOps.SINK, name="root"),
lambda root: UOp(UOps.SINK, root.dtype, tuple(flatten(x.src if x.op in {UOps.SINK, UOps.EXPAND} else (x,) for x in root.src)), root.arg)
if any(x.op in {UOps.SINK, UOps.EXPAND} for x in root.src) else None),
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in {Ops.SINK, Ops.EXPAND} else (x,) for x in root.src)), root.arg)
if any(x.op in {Ops.SINK, Ops.EXPAND} for x in root.src) else None),
])
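# An aside, for orientation: a minimal standalone sketch (plain Python, not tinygrad's
# matcher) of the (pattern, rewriter) idea the rules above are built on -- each rule is
# a predicate on a node plus a callback that returns a replacement or None.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple

@dataclass(frozen=True)
class Node:
  op: str
  src: Tuple["Node", ...] = ()
  arg: object = None

def rewrite(n:Node, rules:List[Tuple[str, Callable[[Node], Optional[Node]]]]) -> Node:
  # rewrite children first, then try each rule on the result (single bottom-up pass)
  n = Node(n.op, tuple(rewrite(s, rules) for s in n.src), n.arg)
  for op, fxn in rules:
    if n.op == op and (ret:=fxn(n)) is not None: return ret
  return n

# toy rule mirroring "GEP of a VECTORIZE selects the source element directly"
rules = [("GEP", lambda g: g.src[0].src[g.arg] if g.src[0].op == "VECTORIZE" else None)]
vec = Node("VECTORIZE", (Node("CONST", arg=1.0), Node("CONST", arg=2.0)))
print(rewrite(Node("GEP", (vec,), 1), rules))   # Node(op='CONST', src=(), arg=2.0)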
# *** uop expander ***
@@ -325,10 +325,10 @@ def _swizzle_args(cargs:Tuple[Tuple[int, int], ...], eargs:Tuple[Tuple[int, int]
return [_expand_arg_to_idx(eargs, {**rpk, **{x:0 for x in exclude_args}} if exclude_args else rpk) for rpk in _choices_from_args(cargs)]
def do_expand(root:UOp):
expands = [x for x in root.src if x.op is UOps.EXPAND]
expands = [x for x in root.src if x.op is Ops.EXPAND]
if len(expands) == 0: return None
# NOTE: we 0 out the reduce axis for WMMA. in theory they should all be the same, but is this always correct?
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is UOps.WMMA else ()
exclude_args = tuple(dedup(root.arg[-1] + tuple(y[0] for y in flatten(root.arg[-2])))) if root.op is Ops.WMMA else ()
if all_same(expands_args:=[x.arg for x in expands]) and len(exclude_args) == 0:
# if there's only one expand arg, it's okay to use it (optimization)
expand_args = expands[0].arg
@@ -338,8 +338,8 @@ def do_expand(root:UOp):
expand_sz = prod([x[1] for x in expand_args])
new_srcs = []
for i,src in enumerate(root.src):
if src.op is UOps.EXPAND:
if root.op is UOps.IF and i == 0:
if src.op is Ops.EXPAND:
if root.op is Ops.IF and i == 0:
# IF means OR on first arg to IF
new_srcs.append(functools.reduce(operator.__or__, [src.src[0].gep(i) for i in range(expand_sz)]))
elif expand_args == src.arg:
@@ -352,70 +352,70 @@ def do_expand(root:UOp):
new_srcs.append(src.src[0].gep(tuple(lst)))
else:
# non-EXPAND input
if (root.op is UOps.IF) or (root.op is UOps.REDUCE and i != 0):
if (root.op is Ops.IF) or (root.op is Ops.REDUCE and i != 0):
# for the first arg of IF and the RANGE args of REDUCE, just pass them through ignoring EXPANDS
new_srcs.append(src)
elif src.dtype.count > 1:
# put any input dtype > 1 grouped together
new_srcs.append(UOp(UOps.VECTORIZE,
new_srcs.append(UOp(Ops.VECTORIZE,
src.dtype.scalar().vec(expand_sz*src.dtype.count), tuple(src.gep(i) for i in range(src.dtype.count))*expand_sz))
else:
# repeat the arg
new_srcs.append(src.broadcast(expand_sz))
new_arg = root.arg
if root.op is UOps.GEP:
if root.op is Ops.GEP:
assert root.dtype.count == 1
# is this right?
new_arg = tuple(range(root.arg[0], new_srcs[0].dtype.count, new_srcs[0].dtype.count // expand_sz))
nsrc = UOp(root.op, root.dtype.scalar().vec(root.dtype.count*expand_sz), tuple(new_srcs), new_arg)
return UOp(UOps.EXPAND, root.dtype, (nsrc,), expand_args)
return UOp(Ops.EXPAND, root.dtype, (nsrc,), expand_args)
def do_contract(con:UOp):
ex = con.src[0]
# CONTRACT without EXPAND repeats the element VECTORIZED
if ex.op is not UOps.EXPAND: return UOp(UOps.VECTORIZE, con.dtype, con.src*con.dtype.count)
if ex.op is not Ops.EXPAND: return UOp(Ops.VECTORIZE, con.dtype, con.src*con.dtype.count)
# CONTRACT may remove several axes from EXPAND
assert con.dtype.count == prod([x[1] for x in con.arg]), "dtype is wrong"
idxs = []
for rpk in _choices_from_args(new_ex_args:=tuple(x for x in ex.arg if x not in con.arg)):
idxs += [_expand_arg_to_idx(ex.arg, {**rpk, **lrpk}) for lrpk in _choices_from_args(con.arg)]
return UOp(UOps.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
return UOp(Ops.EXPAND, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
def no_vectorized_alu(alu):
if alu.dtype.vcount == 1: return None
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(UOps.VECTORIZE, alu.dtype, alus)
return UOp(Ops.VECTORIZE, alu.dtype, alus)
def create_gate(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def _gate_srcs(u:UOp, gate:UOp) -> UOp:
if u.op is UOps.BARRIER: return u
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
return UOp(u.op, u.dtype, u.src[:-1]+(UOp(UOps.IF, dtypes.void, (gate, u.src[-1])),), u.arg)
if u.op is Ops.BARRIER: return u
if u.op is Ops.LOAD and u.src[-1].op is Ops.BARRIER:
return UOp(u.op, u.dtype, u.src[:-1]+(UOp(Ops.IF, dtypes.void, (gate, u.src[-1])),), u.arg)
return u if (replace_source:=tuple(_gate_srcs(x, gate) for x in u.src)) == u.src else UOp(u.op, u.dtype, replace_source, u.arg)
idx = root.src[0]
if idx.op is UOps.CAST: idx = idx.src[0]
return None if idx.op is not UOps.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
if idx.op is Ops.CAST: idx = idx.src[0]
return None if idx.op is not Ops.INDEX or len(idx.src) == 2 or (ret:=_gate_srcs(root, idx.src[2])) is root else ret
expander = PatternMatcher([
# double expand
(UPat(UOps.EXPAND, name="outer", src=(UPat(UOps.EXPAND, name="inner"),)),
lambda outer, inner: UOp(UOps.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
(UPat(Ops.EXPAND, name="outer", src=(UPat(Ops.EXPAND, name="inner"),)),
lambda outer, inner: UOp(Ops.EXPAND, outer.dtype, (inner.src[0],), inner.arg+outer.arg)),
# do expansion
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.GEP, UOps.WMMA, UOps.LOAD, UOps.STORE, UOps.INDEX, UOps.ASSIGN,
UOps.VECTORIZE, UOps.REDUCE, UOps.IF), name="root", custom_early_reject=set([(UOps.EXPAND, None)])), do_expand),
(UPat(UOps.CONTRACT, name="con"), do_contract),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.GEP, Ops.WMMA, Ops.LOAD, Ops.STORE, Ops.INDEX, Ops.ASSIGN,
Ops.VECTORIZE, Ops.REDUCE, Ops.IF), name="root", custom_early_reject=set([(Ops.EXPAND, None)])), do_expand),
(UPat(Ops.CONTRACT, name="con"), do_contract),
# vectorize DEFINE_ACC
(UPat(UOps.VECTORIZE, src=UPat(UOps.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
(UPat(Ops.VECTORIZE, src=UPat(Ops.DEFINE_ACC, name="acc"), name="v"), lambda acc,v: acc.replace(dtype=v.dtype)),
# BARRIERs aren't actually expanded
(UPat(UOps.BARRIER, src=(UPat(UOps.EXPAND, name="ex"),)),
lambda ex: UOp(UOps.EXPAND, dtypes.void, (UOp(UOps.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
(UPat(Ops.BARRIER, src=(UPat(Ops.EXPAND, name="ex"),)),
lambda ex: UOp(Ops.EXPAND, dtypes.void, (UOp(Ops.BARRIER, dtypes.void, ex.src),)*len(ex.src), ex.arg)),
# empty EXPAND is NOOP
(UPat(UOps.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
(UPat(Ops.EXPAND, src=(UPat.var('x'),), arg=()), lambda x: x),
# EXPAND GEP (needed for WMMA, generalize this) -> vectorized ALU
(UPat(UOps.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(UOps.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
(UPat(Ops.EXPAND, name="ex", src=tuple(UPat.var('x').gep(i)+UPat.var('y').gep(i) for i in range(256 if AMX else 8))),
lambda ex,x,y: UOp(Ops.EXPAND, ex.dtype, tuple((x+y).gep(i) for i in range(256 if AMX else 8)), ex.arg)),
])
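# The EXPAND/CONTRACT args used above are tuples of (axis, size) pairs; as a rough
# standalone sketch (illustrative, not tinygrad's actual helpers), enumerating every
# choice of per-axis values and flattening a choice into a single lane index looks like this.
import itertools
from typing import Dict, List, Tuple

def choices_from_args(args:Tuple[Tuple[int, int], ...]) -> List[Dict[int, int]]:
  # one dict {axis: value} per element of the expanded vector
  return [dict(zip([a for a,_ in args], vals)) for vals in itertools.product(*[range(sz) for _,sz in args])]

def arg_to_idx(args:Tuple[Tuple[int, int], ...], choice:Dict[int, int]) -> int:
  # row-major flattening of a choice into a lane index
  idx = 0
  for axis, sz in args: idx = idx*sz + choice[axis]
  return idx

args = ((0, 2), (1, 3))   # total vector width is 2*3 = 6
print([arg_to_idx(args, c) for c in choices_from_args(args)])   # [0, 1, 2, 3, 4, 5]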
def no_vectorized_load_store(ls:UOp):
@@ -423,75 +423,75 @@ def no_vectorized_load_store(ls:UOp):
assert isinstance(idx.dtype, PtrDType)
if idx.dtype.v == 1: return None
tv = [UOp(ls.op, ls.dtype.scalar(), tuple(j.gep(i) for j in ls.src)) for i in range(idx.dtype.v)]
return UOp(UOps.VECTORIZE, ls.dtype, tuple(tv))
return UOp(Ops.VECTORIZE, ls.dtype, tuple(tv))
def no_vectorized_acc(acc:UOp):
if acc.dtype.count == 1: return None
alus = tuple(UOp(acc.op, acc.dtype.scalar(),
tuple(s.gep(i) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
return UOp(Ops.VECTORIZE, acc.dtype, alus)
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.INDEX), name="alu"), no_vectorized_alu),
(UPat(UOps.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(UOps.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((UOps.LOAD, UOps.STORE), name="ls"), no_vectorized_load_store),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.INDEX), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
(UPat(Ops.DEFINE_ACC, name="acc"), no_vectorized_acc),
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
])
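# no_vectorized_alu above splits a width-N op into N scalar ops on GEP'd lanes and
# regroups them with a VECTORIZE; a self-contained toy version of that transform
# (illustrative only, not tinygrad's UOp machinery):
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class T:
  op: str
  count: int = 1                 # vector width of this node
  src: Tuple["T", ...] = ()
  arg: object = None

def gep(x:T, i:int) -> T: return T("GEP", 1, (x,), i)

def devectorize_alu(alu:T) -> Optional[T]:
  if alu.count == 1: return None   # scalars are left alone
  lanes = tuple(T(alu.op, 1, tuple(gep(s, i) for s in alu.src), alu.arg) for i in range(alu.count))
  return T("VECTORIZE", alu.count, lanes)

wide_add = T("ADD", 4, (T("A", 4), T("B", 4)))
out = devectorize_alu(wide_add)
print(out.op, len(out.src))   # VECTORIZE 4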
def delete_redundant_gates(buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:Optional[UOp]=None) -> Optional[UOp]:
if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is UOps.IF]: return None
if store_gate not in [gate.src[0] for gate in val.sparents if gate.op is Ops.IF]: return None
# remove the gate from the index
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val)
load_store_indexing = PatternMatcher([
# late fixup of unfoldable image loads
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
(UPat(Ops.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), fix_unfoldable_image_load),
# simplify valid
(UPat(UOps.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
(UPat(Ops.ALU, name="valid", arg=BinaryOps.AND), simplify_valid),
# image load valid idx simplification
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
# delete_redundant_gates (after expand)
(UPat(UOps.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val"))), delete_redundant_gates),
])
def idx_load_store(x:UOp):
idx = x.src[0].index(x.src[1], x.src[3] if len(x.src) > 3 else None)
v = x.dtype.count if x.op is UOps.LOAD else x.src[2].dtype.count
v = x.dtype.count if x.op is Ops.LOAD else x.src[2].dtype.count
if v > 1 and not isinstance(x.src[0].dtype, ImageDType): idx = idx.cast(idx.dtype.base.vec(v).ptr(idx.dtype.local))
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is UOps.LOAD else x.src[3:])
if x.op is UOps.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
post_mask = x.src[4:] if len(x.src) > 3 else (x.src[2:] if x.op is Ops.LOAD else x.src[3:])
if x.op is Ops.LOAD: return UOp(x.op, x.dtype, (idx,)+post_mask, x.arg)
return UOp(x.op, x.dtype, (idx,x.src[2])+post_mask, x.arg)
migrate_indexing = PatternMatcher([
# use indexing for LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)),), allow_any_len=True, name="x"), idx_load_store),
# create gate MUST BE BEFORE expander
(UPat(UOps.STORE, name="root"), create_gate),
(UPat(Ops.STORE, name="root"), create_gate),
])
def move_mask(x:UOp, buf:UOp, idx:UOp, mask:UOp, cast:Optional[UOp]=None) -> UOp:
# this moves the mask from the indexing to the load/store op for rendering
nidx = buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx)
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is UOps.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
return UOp.load(nidx, x.const_like(0), mask, *x.src[1:], dtype=x.dtype) if x.op is Ops.LOAD else UOp.store(nidx, x.src[1], mask, *x.src[2:])
pm_render = PatternMatcher([
# for rendering, we use explicit VECTORIZE
(UPat(UOps.CONST, name='c'),
lambda c: UOp(UOps.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(UOps.VCONST, name='c'), lambda c: UOp(UOps.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(UOps.GEP, name='gep'), lambda gep: UOp(UOps.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
(UPat(Ops.CONST, name='c'),
lambda c: UOp(Ops.VECTORIZE, c.dtype, (UOp.const(c.dtype.scalar(), c.arg),)*c.dtype.vcount) if c.dtype.vcount > 1 else None),
(UPat(Ops.VCONST, name='c'), lambda c: UOp(Ops.VECTORIZE, c.dtype, tuple(UOp.const(c.dtype.scalar(), x) for x in c.arg))),
(UPat(Ops.GEP, name='gep'), lambda gep: UOp(Ops.VECTORIZE, gep.dtype, tuple(gep.src[0].gep(x) for x in gep.arg)) if len(gep.arg) > 1 else None),
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# move masks of loads/stores
(UPat((UOps.LOAD, UOps.STORE), src=(UPat.any(masked_index:=UPat(UOps.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat.any(masked_index:=UPat(Ops.INDEX, src=(UPat(name="buf"), UPat(name="idx"), UPat(name="mask"))),
masked_index.cast(None).named("cast")),), allow_any_len=True, name="x"), move_mask),
])
# *** uop graph ***
def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
assert sink.op is UOps.SINK, f"sink isn't sink, it's {sink.op}"
assert sink.op is Ops.SINK, f"sink isn't sink, it's {sink.op}"
supported_ops = tuple(opts.code_for_op.keys()) if opts is not None else ()
extra_matcher = opts.extra_matcher if opts is not None and opts.extra_matcher is not None else PatternMatcher([])

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
from tinygrad.ops import UOps
from tinygrad.ops import Ops
# **************** memory planning ****************
@@ -47,5 +47,5 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([si.bufs for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs})
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]
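# The planner above reuses buffers between schedule items when it is safe to do so; as
# a rough illustration of the idea (not the actual algorithm, and ignoring sizes and
# devices), buffers whose lifetimes -- first/last schedule index touching them -- do
# not overlap can share the same allocation.
from typing import Dict, List, Tuple

def plan(lifetimes:Dict[str, Tuple[int, int]]) -> Dict[str, str]:
  assigned: Dict[str, str] = {}
  free: List[Tuple[int, str]] = []   # (lifetime end, physical buffer) available for reuse
  for name, (start, end) in sorted(lifetimes.items(), key=lambda kv: kv[1][0]):
    reuse = next(((e, b) for e, b in free if e < start), None)
    if reuse is not None: free.remove(reuse)
    assigned[name] = name if reuse is None else reuse[1]
    free.append((end, assigned[name]))
  return assigned

# b0 dies at step 1, so b2 (first used at step 2) can live in b0's memory
print(plan({"b0": (0, 1), "b1": (0, 3), "b2": (2, 4)}))   # {'b0': 'b0', 'b1': 'b1', 'b2': 'b0'}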

View File

@@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast, Generator, Tuple
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint
from tinygrad.ops import Ops, UOp, Variable, sym_infer, sint
from tinygrad.dtype import dtypes
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, Program
@@ -181,18 +181,18 @@ class ExecItem:
return et
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is UOps.COPY
if si.ast.op is UOps.SINK:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is Ops.COPY
if si.ast.op is Ops.SINK:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
out, arg = si.outputs[0], si.ast.arg
if si.ast.op is UOps.COPY:
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is UOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is UOps.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:

View File

@@ -2,7 +2,7 @@ import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.ops import BUFFER_UOPS, MetaOps, ReduceOps, UnaryOps, UOp, Ops, PatternMatcher, UPat, Variable, graph_rewrite, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
@@ -15,7 +15,7 @@ from tinygrad.device import Buffer
sys.setrecursionlimit(10000)
BUF_LIMIT = {"METAL":32}
METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
METAOPS = {MetaOps.COPY:Ops.COPY, MetaOps.EMPTY:Ops.EMPTY, MetaOps.VIEW:Ops.BUFFER_VIEW}
# **** ScheduleItem return type
@@ -34,7 +34,7 @@ class ScheduleItem:
"""Read only buffers in the schedule."""
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
@functools.cached_property
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is UOps.SINK else (0,)
def output_idxs(self) -> Tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
# **** small wrapper for LazyBuffer -> UOp
@@ -62,21 +62,21 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
# consts are always fused and generated
if buf.op is MetaOps.CONST:
if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()])
return UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0))
return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0))
# everything else has BUFFER
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(UOps.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))))
ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype))))
# if the buffer is already realized we just load it
if buf.is_realized(): return UOp(UOps.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop()))
# everything else needs sources
src = tuple(to_uop(x, ctx, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(Ops.CONTIGUOUS, dtype, src)
elif buf.op is MetaOps.ASSIGN: ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg)
elif buf.op in METAOPS: ret = UOp(METAOPS[cast(MetaOps, buf.op)], buf.dtype, (ubuf, *src), buf.arg)
elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
elif buf.op is UnaryOps.CAST: ret = UOp(Ops.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(Ops.BITCAST, dtype, src)
else: ret = UOp(Ops.ALU, dtype, src, buf.op)
cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret)))
if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata
if buf.forced_realize: ctx.realizes[ubuf] = ubuf
return ret
@@ -87,7 +87,7 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, cache:Dict[LazyBuffer, UOp]) ->
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
if (n:=cache.get(u)) is not None: return n
if u.op is UOps.VIEW: return u.replace(arg=apply_to_st(u.arg))
if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
return ret
@@ -123,7 +123,7 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzles = [x for x in root.src if x.op is UOps.VIEW and len(x.src) != 0]
swizzles = [x for x in root.src if x.op is Ops.VIEW and len(x.src) != 0]
if len(swizzles) == 0: return None
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
@@ -131,34 +131,34 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
fixup_cache: Dict[UOp, UOp] = {}
new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src]
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
return ret if ret.op is UOps.STORE else ret.view(ShapeTracker.from_shape(new_shape))
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# view before ALU
(UPat(UOps.VIEW, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
(UPat(Ops.VIEW, src=(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, *BUFFER_UOPS), name="e"),), name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
])
# push VIEW to stores
view_right = merge_views+PatternMatcher([
# ASSIGN can override st
(UPat(UOps.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(UOps.ASSIGN, name="a"))),
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
# VIEW on a reduce creates a new VIEW
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.CONTIGUOUS, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
(UPat((Ops.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** ScheduleItem context builder
@@ -180,18 +180,18 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
ctx.bufs.append(x)
return UOp(UOps.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), _append_buf)])
return UOp(Ops.DEFINE_GLOBAL, x.dtype, (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
if b in ctx.assigned: ctx.assign_preloads.append(b)
return x.replace(op=UOps.LOAD)
return x.replace(op=Ops.LOAD)
to_si = PatternMatcher([
(UPat(UOps.VIEW, name="x"), _append_st_vars),
(UPat(UOps.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
(UPat(UOps.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(UOps.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(tuple(METAOPS.values()), name="x")),)), lambda ctx,x: x),
])
# ** fusion
@@ -206,15 +206,15 @@ def full_ast_rewrite(pre:UOp, var_vals:Dict[Variable, int], assigned:Set[UOp]) -
# fuse and fold store -> loads
sink = graph_rewrite(pre, lazy+multioutput if len(pre.src)>1 else lazy, {x.src[0]:x.src[2] for x in pre.src})
# assert cyclic dependency
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {UOps.PRELOAD,UOps.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
for b,ops in itertools.groupby((x for x in sink.sparents if x.op in {Ops.PRELOAD,Ops.LOAD} and x.src[0] in assigned), key=lambda x:x.src[0]):
if not all_same([x.op for x in ops]):
raise RuntimeError(f"cycle detected in kernel.\nhelp: use .contiguous() to break the part loading pre-assign {b} into a different kernel.")
# do movementops
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is Ops.ASSIGN]) != 0:
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.PRELOAD and x.src[0] in assign_targets):
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is Ops.PRELOAD and x.src[0] in assign_targets):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# convert to AST
@@ -232,13 +232,13 @@ if getenv("RUN_PROCESS_REPLAY"):
def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp:
ctx[b] = store
return UOp(UOps.LOAD, load.dtype, (b, load.st_arg.to_uop()))
return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop()))
def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load")
do_realize = PatternMatcher([
# always realize meta ops
(UPatLoadStore(UPat((UOps.ASSIGN, UOps.CONTIGUOUS, *METAOPS.values()))), realize),
(UPat((UOps.COPY, UOps.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
(UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *METAOPS.values()))), realize),
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="v"))), name="root"),
lambda ctx,root,u,v=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if v is None else realize(ctx,**kwargs).view(v.st))),)
])
break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),])

View File

@@ -2,7 +2,7 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
import itertools, functools, random, math, time, multiprocessing, traceback, signal
from collections import defaultdict
from dataclasses import replace
from tinygrad.ops import UOp, UOps, Variable, sym_infer
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.dtype import ImageDType, PtrDType
@@ -88,7 +88,7 @@ def _ensure_buffer_alloc(bufs:List[Buffer]) -> List[Buffer]: return [buf.ensure_
def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
for x in lin.bufs:
if x.src[0].op is UOps.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
if x.src[0].op is Ops.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
for k,lx in bufsts.items():
buf_size = prod(dtype.shape) if isinstance(dtype:=lx[0].src[0].dtype, ImageDType) else max(y.st_arg.real_size() for y in lx)

View File

@@ -123,8 +123,8 @@ REDUCE_ALU: Dict[ReduceOps, BinaryOps] = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt)
# the order of these UOps controls the order of the toposort
class UOps(FastEnum):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto()
CONTIGUOUS = auto()
@@ -183,9 +183,9 @@ class UOps(FastEnum):
VCONST = auto()
CONST = auto()
BUFFER_UOPS = {UOps.LOAD, UOps.PRELOAD, UOps.STORE, UOps.VALID}
BUFFER_UOPS = {Ops.LOAD, Ops.PRELOAD, Ops.STORE, Ops.VALID}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
END_FOR_UOP = {Ops.IF:(Ops.STORE, Ops.ENDIF), Ops.RANGE:(Ops.ASSIGN, Ops.ENDRANGE)}
# With True as the default, this matches the old symbolic behavior
def resolve(x, default:bool=True):
@@ -218,14 +218,14 @@ def pretty_print(x:Any, rep:Callable, srcfn=lambda x: x.src, cache=None, d=0)->s
class UOpMetaClass(type):
ucache:WeakValueDictionary[Tuple, UOp] = WeakValueDictionary()
def __call__(cls, op:UOps, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:Tuple[UOp,...]=tuple(), arg:Any=None):
if (ret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None: return ret
UOpMetaClass.ucache[key] = ret = super().__call__(op, dtype, src, arg)
return ret
class UOp(MathTrait, metaclass=UOpMetaClass):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op:UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
def __init__(self, op:Ops, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
@@ -243,7 +243,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def key(self) -> bytes:
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@functools.cached_property # parents with self
@@ -251,31 +251,31 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
return (self.op.value, self.arg.value if self.op is UOps.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
return (self.op.value, self.arg.value if self.op is Ops.ALU else self.arg, self.dtype, tuple(x.tuplize for x in self.src))
# *** uop shape stuff ***
@property
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.BUFFER, UOps.CONST, UOps.DEFINE_VAR}
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
if not self.has_st: return None
if self.op in BUFFER_UOPS: return self.st_arg
if self.op is UOps.VIEW: return self.arg
if self.op is Ops.VIEW: return self.arg
src_sts = [x.st for x in self.src if x.st is not None]
assert all_same([x.shape for x in src_sts]), f"UOp parents must have the same shape {self} {[x.shape for x in src_sts]}"
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is Ops.REDUCE_AXIS else src_sts[0]
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
return self.arg.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
# *** uop evaluation ***
def simplify(self):
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, symbolic)
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
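# Usage sketch for simplify/ssimplify (assuming this revision of tinygrad is importable):
# a purely constant expression folds to a CONST, and ssimplify unwraps it to a Python value.
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
e = UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 3)
print(e.ssimplify())   # 5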
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
vmin, vmax = (simple_self:=self.simplify())._min_max
@@ -294,98 +294,98 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def st_arg(self) -> ShapeTracker:
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
ret = self.src[0 if self.op is UOps.VALID else 1]
assert ret.op is UOps.VIEW, f"st_arg trying to return {ret}"
ret = self.src[0 if self.op is Ops.VALID else 1]
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
return ret.arg
@property
def axis_arg(self) -> Tuple[int, ...]:
assert self.op in {UOps.REDUCE_AXIS, UOps.WMMA}, f"axis_arg called on {self.op}"
ret = self.arg[1] if self.op is UOps.REDUCE_AXIS else self.arg[7]
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
assert isinstance(ret, tuple) and all(isinstance(x, int) for x in ret), f"axis_arg trying to return {ret}"
return ret
def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st:ShapeTracker): return UOp(UOps.VIEW, self.dtype, (self,), st)
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st)
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
return UOp(UOps.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType): return UOp(UOps.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return UOp(UOps.BITCAST, dtype, (self,))
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:Union[Tuple[int, ...], int]):
if isinstance(i, int):
# NOTE: these are just shortcuts to not have to create and fold later
if self.op is UOps.VECTORIZE: return self.src[i]
if self.op is UOps.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
if self.op is UOps.CONST: return UOp.const(self.dtype.scalar(), self.arg)
if self.op is Ops.VECTORIZE: return self.src[i]
if self.op is Ops.VCONST: return UOp.const(self.dtype.scalar(), self.arg[i])
if self.op is Ops.CONST: return UOp.const(self.dtype.scalar(), self.arg)
i = (i,)
if self.dtype == dtypes.void or (i == tuple(range(len(i))) and self.dtype.vcount == len(i)): return self
assert len(i) >= 1 and all(x < self.dtype.vcount for x in i), f"bad GEP on {self.dtype}, {i}"
return UOp(UOps.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(UOps.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
return UOp(Ops.GEP, self.dtype.scalar().vec(len(i)) if len(i) > 1 else self.dtype.scalar(), (self,), i)
def load(self, *src:UOp, **kwargs): return UOp(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def alu(self, arg, *src:UOp):
out_dtype = (self, *src)[-1].dtype
if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None:
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(UOps.ALU, out_dtype, (self,)+src, arg)
return UOp(Ops.ALU, out_dtype, (self,)+src, arg)
@staticmethod
def const(dtype:DType, b:ConstLike):
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
def assign(self, x:UOp): return UOp(UOps.ASSIGN, self.dtype, (self,x))
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(Ops.REDUCE, self.dtype, (self,) + rng, op)
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# *** uop Variable stuff ***
@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
assert not isinstance(min_val, UOp) and not isinstance(max_val, UOp), f"can't create Variable {name} with {min_val}/{max_val}"
return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
return UOp(Ops.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self):
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int):
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
assert self.op is Ops.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
return UOp(Ops.BIND, self.dtype, (self, self.const_like(val)))
def unbind(self) -> Tuple[Variable, int]:
assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
@property
def val(self) -> int: return self.unbind()[1]
def vars(self) -> Set[UOp]:
bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
bound_vars = set([x for x in self.sparents if x.op is Ops.BIND and x.src[0].op is Ops.DEFINE_VAR])
bound_var_base = set(x.src[0] for x in bound_vars)
all_vars = set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
all_vars = set([x for x in self.sparents if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
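# Usage sketch for the Variable helpers above (assuming this revision of tinygrad is
# importable): a DEFINE_VAR carries (name, min, max), BIND pairs it with a concrete
# value, and unbind splits the pair back apart.
from tinygrad.ops import UOp
i = UOp.variable("i", 0, 10)
bound = i.bind(3)
var, val = bound.unbind()
print(var is i, val)   # True 3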
# *** uop symbolic stuff ***
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
if self.op is UOps.VCONST: return functools.reduce(math.gcd, self.arg)
if self.op is UOps.ALU:
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg)
if self.op is Ops.ALU:
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
def divides(self, v) -> Optional[UOp]:
if v==1: return self
if self.op is UOps.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is UOps.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is UOps.ALU:
if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None
if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None
if self.op is Ops.ALU:
if self.arg is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None
if self.arg is BinaryOps.MUL:
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
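# Worked example for const_factor/divides (assuming this revision of tinygrad is
# importable): for 4*x + 8 the largest known integer factor is gcd(4, 8) = 4, and
# dividing the whole expression by it yields something equivalent to x + 2.
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
x = UOp.variable("x", 0, 100)
e = UOp.const(dtypes.int, 4)*x + 8
print(e.const_factor())         # 4
print(e.divides(4).render())    # renders an expression equivalent to x + 2, e.g. "(x+2)"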
@@ -398,20 +398,20 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def _min_max(self) -> Tuple[ConstType, ConstType]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
if self.op is UOps.BIND: return self.src[0].vmin, self.src[0].vmax # ignore the bound value
if self.op in {UOps.EXPAND, UOps.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
if self.op is Ops.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op is Ops.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
if self.op is Ops.BIND: return self.src[0].vmin, self.src[0].vmax # ignore the bound value
if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is UOps.CONST: return self.arg, self.arg
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
if self.op is UOps.ALU and not dtypes.is_float(self.dtype):
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is Ops.CONST: return self.arg, self.arg
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
if self.op is Ops.ALU and not dtypes.is_float(self.dtype):
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL: return min(vals:=(s0.vmin*s1.vmin, s0.vmin*s1.vmax, s0.vmax*s1.vmin, s0.vmax*s1.vmax)), max(vals)
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if self.arg is BinaryOps.IDIV and s1.op is Ops.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
if s1.arg < 0 and s0.vmin >= 0: return -(s0.vmax//-s1.arg), -(s0.vmin//-s1.arg)
if self.arg is BinaryOps.MAX: return max(s0.vmin, s1.vmin), max(s0.vmax, s1.vmax)
@@ -430,7 +430,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def _sym_fxn(self):
sself = self.simplify()
varnames = tuple(x.arg[0] for x in sself.sparents if x.op is UOps.DEFINE_VAR)
varnames = tuple(x.arg[0] for x in sself.sparents if x.op is Ops.DEFINE_VAR)
# TODO: sanitize varnames, or don't use naked eval while staying fast
return eval("lambda "+','.join(varnames)+": "+sself.render()), varnames # pylint: disable=eval-used
@@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def render(self, simplify=True) -> str:
ret = graph_rewrite(self.simplify() if simplify else self, renderer)
return ret.arg if ret.op is UOps.NOOP else str(ret)
return ret.arg if ret.op is Ops.NOOP else str(ret)
@dataclass(frozen=True)
class KernelInfo:
@@ -476,7 +476,7 @@ def exec_alu(op:Op, dtype:DType, operands, truncate_output=True):
def print_uops(uops:List[UOp]):
for i,u in enumerate(uops):
formatted_parents = [uops.index(x) if x.op is not UOps.CONST else f"{x.arg}" for x in u.src]
formatted_parents = [uops.index(x) if x.op is not Ops.CONST else f"{x.arg}" for x in u.src]
print(f"{i:4d} {str(u.op):20s}: {str(u.dtype):25s} " f"{str(formatted_parents):32s} {u.arg}")
def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
@@ -487,26 +487,26 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
dont_count: Set[UOp] = set()
if ignore_indexing:
for u in uops:
if u.op in {UOps.LOAD, UOps.STORE}:
if u.op in {Ops.LOAD, Ops.STORE}:
dont_count = dont_count.union(u.src[0].sparents)
if len(u.src) > 2: dont_count = dont_count.union(u.src[2].sparents)
elif u.op is UOps.IF:
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].sparents)
for u in uops:
if u.op is UOps.RANGE:
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= (u.src[1] - u.src[0]).ssimplify()
elif u.op is UOps.ENDRANGE:
elif u.op is Ops.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.op is UOps.SPECIAL:
elif u.op is Ops.SPECIAL:
mults *= u.arg[1] # NOTE: we don't push to the mult_stack here, you can't end these
elif u.op is UOps.LOAD:
elif u.op is Ops.LOAD:
mem += u.dtype.itemsize * mults
elif u.op is UOps.STORE:
elif u.op is Ops.STORE:
mem += u.src[1].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
elif u.op is Ops.ALU and u not in dont_count:
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is UOps.WMMA and u not in dont_count:
elif u.op is Ops.WMMA and u not in dont_count:
flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
return flops, mem
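# Worked example for flops_mem: inside a single RANGE of length 256, one float32 LOAD,
# one ADD, and one float32 STORE are counted as mem = 256*4 + 256*4 = 2048 bytes and
# flops = 256*1 = 256 (MULACC would double the flop term, per the code above).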
@@ -525,11 +525,11 @@ def lines(fn) -> List[str]:
class UPat(MathTrait):
__slots__ = ["op", "dtype", "arg", "name", "src"]
def __init__(self, op:Optional[Union[UOps, Tuple[UOps, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
def __init__(self, op:Optional[Union[Ops, Tuple[Ops, ...]]]=None, dtype:Optional[Union[DType, Tuple[DType, ...]]]=None,
src:Optional[Union[Tuple[UPat, ...], List[UPat], UPat]]=None, arg:Any=None,
name:Optional[str]=None, allow_any_len:bool=False, location=None,
custom_early_reject:Optional[Set[Tuple[UOps, Any]]]=None):
self.op: Optional[Tuple[UOps, ...]] = (op,) if isinstance(op, UOps) else op
custom_early_reject:Optional[Set[Tuple[Ops, Any]]]=None):
self.op: Optional[Tuple[Ops, ...]] = (op,) if isinstance(op, Ops) else op
self.dtype: Optional[Tuple[DType, ...]] = (dtype,) if isinstance(dtype, DType) else dtype
self.arg, self.name, self._in_src, self.custom_early_reject = arg, name, src, custom_early_reject
self.src: Any = None
@@ -561,24 +561,24 @@ class UPat(MathTrait):
@staticmethod
@functools.lru_cache(None)
def cvar(name:Optional[str]=None, dtype:Optional[DType]=None, vec=True):
return UPat((UOps.CONST, UOps.VCONST) if vec else UOps.CONST, dtype=dtype, name=name)
return UPat((Ops.CONST, Ops.VCONST) if vec else Ops.CONST, dtype=dtype, name=name)
@staticmethod
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(UOps.CONST, dtype=dtype, arg=b)
def const(dtype:Optional[Union[DType, Tuple[DType, ...]]], b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# copied from UOp
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(UOps.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(UOps.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None): return UPat(UOps.CAST, dtype, (self,))
def bitcast(self, dtype=None): return UPat(UOps.BITCAST, dtype, (self,))
def gep(self, i:int): return UPat(UOps.GEP, None, (self,), (i,))
def load(self, *src:UPat, **kwargs): return UPat(UOps.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(UOps.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UPat): return UPat(UOps.ASSIGN, self.dtype, (self,x))
def index(self, idx:UPat, valid:Optional[UPat]=None): return UPat(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st=None, **kwargs): return UPat(Ops.VIEW, self.dtype, (self,), st, **kwargs)
def cast(self, dtype=None): return UPat(Ops.CAST, dtype, (self,))
def bitcast(self, dtype=None): return UPat(Ops.BITCAST, dtype, (self,))
def gep(self, i:int): return UPat(Ops.GEP, None, (self,), (i,))
def load(self, *src:UPat, **kwargs): return UPat(Ops.LOAD, src=(self,)+src, **kwargs)
def store(self, *src:UPat, **kwargs): return UPat(Ops.STORE, dtypes.void, (self,)+src, **kwargs)
def assign(self, x:UPat): return UPat(Ops.ASSIGN, self.dtype, (self,x))
def const_like(self, b:ConstLike): return UPat.const(self.dtype, cast(ConstType, b))
def alu(self, arg, *src:UPat):
asrc = (self,)+src
return UPat(UOps.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
return UPat(Ops.ALU, None if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else asrc[-1].dtype, list(asrc) if arg in COMMUTATIVE else asrc, arg)
def printable(self:UPat) -> str:
try: return lines(self.location[0])[self.location[1]-1].strip()
@@ -627,7 +627,7 @@ class PatternMatcher:
def __init__(self, patterns:List[Tuple[UPat, Callable]]):
self.patterns = patterns
# NOTE: use of DefaultDict here is very dangerous! all keys will live for the lifetime of the PatternMatcher!
self.pdict: Dict[Tuple[UOps, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
self.pdict: Dict[Tuple[Ops, Any], List[Tuple[UPat, Callable, Set, bool]]] = {}
# uop is required, arg is optional
for p,fxn in self.patterns:
assert p.op is not None
@@ -745,84 +745,84 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
# this is the matcher for the final rendered UOps
# matcher functions returns True or False (or None to not match)
spec = PatternMatcher([
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(UPat(UOps.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(Ops.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(UPat(Ops.DEFINE_ACC, src=(UPat.var("c"),), name="x", allow_any_len=True),
lambda x,c: all(y.op is Ops.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(Ops.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(UOps.SPECIAL, src=()), lambda: True),
(UPat(Ops.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(Ops.SPECIAL, src=()), lambda: True),
# TODO: confirm the args of both of these are shapetrackers
(UPat(UOps.VIEW, src=()), lambda: True),
(UPat(UOps.VIEW, src=(UPat(),)), lambda: True),
(UPat(Ops.VIEW, src=()), lambda: True),
(UPat(Ops.VIEW, src=(UPat(),)), lambda: True),
(UPat(UOps.VALID, dtypes.bool, (UPat(UOps.VIEW),)), lambda: True),
(UPat(UOps.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
(UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
# early LOAD has a <buf, shapetracker, store?>
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW))), lambda: True),
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat(UOps.STORE))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat(Ops.STORE))), lambda: True),
# early STORE has a <buf, shapetracker, val>
(UPat(UOps.STORE, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.VIEW), UPat())), lambda: True),
(UPat(Ops.STORE, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(Ops.VIEW), UPat())), lambda: True),
# **** new style load/store ****
# INDEX is used in new style load/store
(UPat(UOps.INDEX, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
# LOAD takes a <bufidx, alt?, gate?, barrier?>
(UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)),)), lambda: True),
(UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
(UPat(UOps.LOAD, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)),)), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat((Ops.IF, Ops.BARRIER)))), lambda: True),
(UPat(Ops.LOAD, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"), lambda ld,alt: ld.dtype == alt.dtype),
# STORE takes a <bufidx, val, gate?>
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat())), lambda: True),
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.INDEX, UOps.CAST)), UPat(), UPat(UOps.IF))), lambda: True),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat())), lambda: True),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat((Ops.INDEX, Ops.CAST)), UPat(), UPat(Ops.IF))), lambda: True),
# most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
(UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
(UPat(Ops.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
lambda w,x,y: w.dtype == x.dtype == y.dtype),
(UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
(UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
(UPat(Ops.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
# and SHL/SHR, the shift distance is an int
(UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHL),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat(UOps.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
(UPat(Ops.ALU, src=(UPat(name="x"), UPat(name="y")), name="alu", arg=BinaryOps.SHR),
lambda alu,x,y: alu.dtype == x.dtype and (x.dtype == y.dtype or y.dtype == dtypes.uint)),
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(Ops.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(Ops.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
(UPat(Ops.ASSIGN, src=(UPat((Ops.DEFINE_ACC, Ops.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(Ops.ENDRANGE, dtype=dtypes.void, src=(UPat(Ops.RANGE),)), lambda: True),
# all WMMA has 3 args, <x, w, acc>
(UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
(UPat(Ops.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(Ops.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier?>
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
(UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(),)), lambda: True),
(UPat(Ops.IF, dtype=dtypes.void, src=(UPat(), UPat(Ops.BARRIER))), lambda: True),
(UPat(Ops.ENDIF, dtype=dtypes.void, src=(UPat(Ops.IF),)), lambda: True),
(UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
(UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
(UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
(UPat(Ops.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
(UPat(Ops.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(Ops.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((Ops.BITCAST, Ops.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None),
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, allow_any_len=True)), lambda: True), # NOTE: all pointers must be local
# NOTE: for testing, we let sinks be anything
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
(UPat(UOps.SINK, dtypes.void), lambda: True),
(UPat(UOps.NOOP), lambda: True),
(UPat(Ops.SINK, dtypes.void), lambda: True),
(UPat(Ops.NOOP), lambda: True),
# PTX LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
(UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
(UPat(Ops.BARRIER, dtypes.void, src=UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
])
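
A minimal sketch of how a validation matcher like spec is consulted (assuming tinygrad at this commit is importable; toy_spec is a made-up single-rule matcher, not the real spec): every rendered UOp must satisfy at least one pattern whose callback returns True.

from tinygrad.ops import UOp, Ops, UPat, PatternMatcher
from tinygrad.dtype import dtypes

# a one-rule stand-in: CONSTs must carry a plain python scalar
toy_spec = PatternMatcher([
  (UPat(Ops.CONST, name="x"), lambda x: type(x.arg) in (int, float, bool)),
])
print(toy_spec.rewrite(UOp.const(dtypes.int, 3)))  # True when the pattern matches and the check passes, None otherwise
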
def type_verify(uops:List[UOp]):
@@ -842,7 +842,7 @@ def cast_float_to_bf16(x: UOp) -> UOp:
# *** most of symbolic lives here now ***
def split_uop(x:UOp, sep:BinaryOps):
if x.op is UOps.ALU and x.arg is sep:
if x.op is Ops.ALU and x.arg is sep:
for s in x.src: yield from split_uop(s, sep)
else: yield x
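
A small usage sketch of split_uop (assuming tinygrad at this commit is on the path): it flattens a chain of the given binary op into its terms, which is what the mod/div folding below iterates over.

from tinygrad.ops import UOp, BinaryOps, split_uop
from tinygrad.dtype import dtypes

a = UOp.variable("a", 0, 10, dtypes.int)
b = UOp.variable("b", 0, 10, dtypes.int)
expr = a + b + UOp.const(dtypes.int, 3)  # ((a+b)+3) built as nested ADDs
print([u.op for u in split_uop(expr, BinaryOps.ADD)])  # three terms: a, b and the CONST
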
@@ -859,7 +859,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]:
assert divides is not None
remainder.append(divides)
something_changed = True
elif u.op is UOps.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is UOps.CONST and s1.arg%c == 0:
elif u.op is Ops.ALU and u.arg is BinaryOps.MOD and (s1:=u.src[1]).op is Ops.CONST and s1.arg%c == 0:
remainder.append(u.src[0])
something_changed = True
else: remainder.append(u)
@@ -874,7 +874,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1
for u in split_uop(x, BinaryOps.ADD):
if u.op is UOps.CONST:
if u.op is Ops.CONST:
# add all const together first
if rem_const != 0: something_changed = True
rem_const += u.arg
@@ -886,7 +886,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
something_changed = True
else:
# divisor is the smallest common divisor of all MULs
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
if u.op is Ops.ALU and u.arg is BinaryOps.MUL and factor > 1 and c % factor == 0 and (divisor == 1 or divisor > factor): divisor = factor
remainder.append(u)
gcd = math.gcd(gcd, factor)
@@ -917,11 +917,11 @@ def fold_unrolled_divs(divs:UOp):
# example: x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x
add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None
for u in add_chain:
if not (u.op is UOps.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is UOps.CONST): return None
if not (u.op is Ops.ALU and u.arg is BinaryOps.IDIV and u.src[1].op is Ops.CONST): return None
if denominator is None: denominator = u.src[1].arg
if denominator != u.src[1].arg: return None
# assumed CONST is the last of an ADD
if (s0:=u.src[0]).op is UOps.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is UOps.CONST and s0.src[1].op is UOps.CONST:
if (s0:=u.src[0]).op is Ops.ALU and s0.arg is BinaryOps.ADD and s0.src[1].op is Ops.CONST and s0.src[1].op is Ops.CONST:
seen_const.append(s0.src[1].arg)
s0 = s0.src[0]
else: seen_const.append(0)
@@ -933,7 +933,7 @@ def fold_unrolled_divs(divs:UOp):
if ans is not None and 0 <= ans.vmin and ans.vmax + i < denominator: seen_const.append(i)
return ans if ans is not None and sorted(seen_const)==list(range(denominator)) else None
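
A hedged sketch of the fold above (assuming tinygrad at this commit): the unrolled chain x//4 + (x+1)//4 + (x+2)//4 + (x+3)//4 is expected to simplify back to x through this rule.

from tinygrad.ops import UOp
from tinygrad.dtype import dtypes

x = UOp.variable("x", 0, 100, dtypes.int)
unrolled = x//4 + (x+1)//4 + (x+2)//4 + (x+3)//4
print(unrolled.simplify())  # expected to print the DEFINE_VAR for x
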
def is_irreducible(u:UOp): return u.op in (UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
def is_irreducible(u:UOp): return u.op in (Ops.CONST, Ops.DEFINE_VAR, Ops.SPECIAL, Ops.RANGE)
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
@@ -941,7 +941,7 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
changed, ret = False, []
for u in split_uop(X, BinaryOps.ADD):
# assumed the const is the last src of MUL
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
if u.op is Ops.ALU and u.arg is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0:
changed = True
u = u.src[0]
if not (is_irreducible(u) and u.vmin >= 0): return None
@@ -951,8 +951,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]:
def is_increasing(f:UOp) -> bool:
# is f a monotonically increasing function with regard to its input
if is_irreducible(f): return True
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
if f.op is Ops.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
if f.op is Ops.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
return False # False if not sure
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
@@ -960,10 +960,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X >= c, returns X, False, c
# (X < c).ne(True) -> X >= c
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is UOps.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is UOps.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is UOps.CONST: return s0.src[0], False, s0.src[1].arg
if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \
(s0:=valid.src[0]).op is Ops.ALU and s0.arg is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg
# X < c -> X <= c-1
if valid.op is UOps.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is UOps.CONST: return valid.src[0], True, valid.src[1].arg-1
if valid.op is Ops.ALU and valid.arg is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1
raise ValueError(f"not able to parse {valid=}")
def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
@@ -983,7 +983,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
# every candidate is a set of constrained UOps based on valid; if every item in a set simplifies the uop to the same output, we rewrite the uop
candidates = []
if expr.op is UOps.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
if expr.op is Ops.ALU and expr.arg is BinaryOps.ADD and all(is_irreducible(u) and v[0] == 1 for u in split_uop(expr, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)])
# try checking the whole clause
@@ -992,7 +992,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]:
for candidate in candidates:
# if every branch in candidate gives the same simplified uop, we can rewrite the uop
newuops = [uop.substitute({X:newX}).simplify().substitute({newX:X}).simplify() for X,newX in candidate]
if uop.op is UOps.VECTORIZE and len(uop.src) == 2:
if uop.op is Ops.VECTORIZE and len(uop.src) == 2:
if all_same([uops.src[0] for uops in newuops]): uop = uop.replace(src=(newuops[0].src[0], uop.src[1]))
if all_same([uops.src[1] for uops in newuops]): uop = uop.replace(src=(uop.src[0], newuops[0].src[1]))
elif all_same(newuops): uop = newuops[0]
@@ -1037,17 +1037,17 @@ symbolic_simple = PatternMatcher([
# NOTE: this can be wrong for loaded NaN
(UPat.var("x") * 0, lambda x: x.const_like(float("nan") if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
# ** constant folding **
(UPat(UOps.ALU, name="root", src=UPat((UOps.VCONST, UOps.CONST))),
(UPat(Ops.ALU, name="root", src=UPat((Ops.VCONST, Ops.CONST))),
lambda root: root.const_like(exec_alu(root.arg, root.dtype, [x.arg for x in root.src], truncate_output=False))),
# ** COMMUTATIVE flipping **
*[(UPat(UOps.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
*[(UPat(Ops.ALU, arg=op, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None) for op in COMMUTATIVE],
# bool MUL is AND, ADD/MAX is OR. prevents other rules from rewriting bool ADD/MUL incorrectly
(UPat.var('x', dtype=dtypes.bool) * UPat.var('y', dtype=dtypes.bool), lambda x,y: x&y),
(UPat.var('x', dtype=dtypes.bool) + UPat.var('y', dtype=dtypes.bool), lambda x,y: x|y),
(UPat.var('x', dtype=dtypes.bool).maximum(UPat.var('y', dtype=dtypes.bool)), lambda x,y: x|y),
# *** cast ***
(UPat(UOps.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(UOps.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
(UPat(Ops.CAST, name="root", src=UPat.cvar("c")), lambda root, c: root.const_like(c.arg)),
(UPat(Ops.CAST, name="root"), lambda root: root.src[0] if root.dtype == root.src[0].dtype else None),
])
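
A tiny sketch of the constant-folding rule above (assuming tinygrad at this commit): an ALU whose sources are all CONSTs collapses into a single CONST under simplify().

from tinygrad.ops import UOp
from tinygrad.dtype import dtypes

print((UOp.const(dtypes.int, 2) + UOp.const(dtypes.int, 3)).simplify())  # a CONST UOp carrying 5
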
symbolic = symbolic_simple+PatternMatcher([
@@ -1063,7 +1063,7 @@ symbolic = symbolic_simple+PatternMatcher([
(UPat.var().where(UPat.var("val"), UPat.var("val")), lambda val: val),
(UPat.cvar("gate", vec=False).where(UPat.var("c0"), UPat.var("c1")), lambda gate, c0, c1: c0 if gate.arg else c1),
# ALU min==max -> CONST (slow!)
(UPat(UOps.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
(UPat(Ops.ALU, name="x"), lambda x: x.const_like(x.vmin) if x.vmin == x.vmax else None),
# max folding
(UPat.maximum(UPat.var("x"), UPat.var("y")), lambda x,y: x if x.vmin >= y.vmax else y if x.vmax <= y.vmin else None),
# TODO: why does this rule break beautiful_mnist?
@@ -1090,11 +1090,11 @@ symbolic = symbolic_simple+PatternMatcher([
(((UPat.cvar("c0", vec=False)*UPat.var("x"))+UPat.var("x2")).lt(UPat.cvar("c1", vec=False)),
lambda x,x2,c0,c1: x.lt(c1//c0) if c1.arg % c0.arg == 0 and c0.arg > x2.vmax and x2.vmin >= 0 else None),
# ** move add/mul consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
(UPat(Ops.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
(UPat(Ops.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
# *** rules from symbolic ***
# unrolled arange div folding
(UPat(UOps.ALU, name="divs", src=[UPat(), UPat(UOps.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
(UPat(Ops.ALU, name="divs", src=[UPat(), UPat(Ops.ALU, arg=BinaryOps.IDIV)], arg=BinaryOps.ADD), fold_unrolled_divs),
# generic lt folding
(UPat.var("x", dtypes.sints).lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg else None),
# canonicalize a simplex with positive coefficients > 0
@@ -1115,23 +1115,23 @@ symbolic_flat = symbolic+PatternMatcher([
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
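
A short sketch of what _substitute enables (assuming tinygrad at this commit): UOp.substitute maps chosen subtrees to replacements across the graph, which is how uop_given_valid above tries out constrained variables.

from tinygrad.ops import UOp
from tinygrad.dtype import dtypes

x = UOp.variable("x", 0, 10, dtypes.int)
expr = x * 2
print(expr.substitute({x: UOp.const(dtypes.int, 4)}).simplify())  # expected to fold to a CONST 8
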
# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>",
BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
renderer = PatternMatcher([
(UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(UOps.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(UOps.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(UOps.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
(UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])),
(UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(Ops.CONST, name="x"), lambda x: UOp(Ops.NOOP, arg=str(x.arg))),
(UPat(Ops.BIND, src=UPat(Ops.NOOP), name="x"), lambda x: x.src[0]),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=UnaryOps.NEG), lambda x: UOp(Ops.NOOP, arg=f"(-{x.src[0].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=BinaryOps.MAX), lambda x: UOp(Ops.NOOP, arg=f"max({x.src[0].arg}, {x.src[1].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.MULACC),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg}+{x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x", arg=TernaryOps.WHERE),
lambda x: UOp(Ops.NOOP, arg=f"({x.src[1].arg} if {x.src[0].arg} else {x.src[2].arg})")),
(UPat(Ops.ALU, src=UPat(Ops.NOOP), name="x"), lambda x: UOp(Ops.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
])
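
A hedged sketch of the debug renderer above (assuming tinygrad at this commit): graph_rewrite with renderer collapses a symbolic tree into a NOOP whose arg is a printable string.

from tinygrad.ops import UOp, graph_rewrite, renderer
from tinygrad.dtype import dtypes

x = UOp.variable("x", 0, 10, dtypes.int)
print(graph_rewrite(x + 1, renderer).arg)  # something like "(x+1)"
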
# *** what was symbolic.py ***

View File

@@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod
from tinygrad.ops import Op, UOps, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.ops import Op, Ops, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.dtype import DType
@dataclass(frozen=True)
@@ -42,10 +42,10 @@ class Program:
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(u)
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:
if u.op is Ops.DEFINE_VAR: self.vars.append(u)
if u.op is Ops.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is Ops.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is Ops.DEFINE_GLOBAL])
if u.op is Ops.SPECIAL:
# NOTE: you have to set local_size and global_size to the base [1,1,1] outside this
if u.arg[0][0] == 'i': self.local_size = None
special_size = self.local_size if u.arg[0][0] == 'l' else self.global_size

View File

@@ -2,66 +2,66 @@ from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
base_rewrite = PatternMatcher([
(UPat(UOps.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(UOps.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(UPat(UOps.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda ctx: "}"),
(UPat(UOps.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
(UPat(Ops.DEFINE_ACC, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.ASSIGN, name="x"), lambda ctx,x: f"{ctx[x.src[0]]} = {ctx[x.src[1]]};"),
(UPat(Ops.IF, name="x"), lambda ctx,x: f"if ({ctx[x.src[0]]}) {{"),
(UPat((Ops.ENDIF, Ops.ENDRANGE)), lambda ctx: "}"),
(UPat(Ops.WMMA, name="x"), lambda ctx,x: f"__{x.arg[0]}({ctx[x.src[0]]}, {ctx[x.src[1]]}, {ctx[x.src[2]]})"),
# r method accesses
(UPat(UOps.RANGE, name="x"),
(UPat(Ops.RANGE, name="x"),
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"),
(UPat(Ops.VECTORIZE, name="x"),
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")),
(UPat(UOps.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
(UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
(UPat(UOps.BARRIER), lambda ctx: ctx.barrier),
(UPat(UOps.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(UOps.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"),
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
# const
(UPat(UOps.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(UPat(UOps.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(UPat(UOps.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(UOps.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(UOps.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(UOps.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
(UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),
(UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"),
(UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"),
# consts are rendered to a larger type and cast
(UPat(UOps.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(UPat(UOps.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(UPat(UOps.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
(UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"),
(UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"),
(UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"),
# default const render
(UPat(UOps.CONST, name="x"), lambda ctx,x: str(x.arg)),
(UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)),
# new load/store
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"),
(UPat(UOps.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(UOps.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(UOps.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"),
(UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"),
(UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"),
# alu/gep
(UPat(UOps.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
(UPat(Ops.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.arg](
*([strip_parens(ctx[v]) if v.arg == x.arg and x.arg in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)),
(UPat(UOps.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")),
])
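
A reduced sketch of the ctx convention base_rewrite relies on (assumption: a toy matcher, not the real renderer): callbacks that declare a ctx parameter receive whatever is passed as the rewrite's ctx, which CStyleLanguage uses below to look up already-rendered names.

from tinygrad.ops import UOp, Ops, UPat, PatternMatcher
from tinygrad.dtype import dtypes

toy = PatternMatcher([(UPat(Ops.CONST, name="x"), lambda ctx,x: f"{ctx['prefix']}{x.arg}")])
print(toy.rewrite(UOp.const(dtypes.int, 7), ctx={"prefix": "const_"}))  # "const_7"
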
extra_pm = PatternMatcher([
# insert a NOOP before BITCAST to force it to be rendered. not needed on all backends?
(UPat(UOps.BITCAST, name="x"),
lambda x: UOp(UOps.BITCAST, x.dtype, (UOp(UOps.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not UOps.NOOP else None),
(UPat(Ops.BITCAST, name="x"),
lambda x: UOp(Ops.BITCAST, x.dtype, (UOp(Ops.NOOP, x.src[0].dtype, x.src),)) if x.src[0].op is not Ops.NOOP else None),
# gate any stores that aren't gated with ifs
(UPat(UOps.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(UOps.STORE, src=store.src[:2]+(UOp(UOps.IF, src=(store.src[2],)),))),
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE (max function is annoying on many cstyle backends)
(UPat(UOps.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
(UPat(Ops.ALU, name="m", arg=BinaryOps.MAX), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])
def uops_to_dtypes(uops:List[UOp]) -> List[DType]: return dedup(u.dtype for u in uops if not isinstance(u.dtype, (ImageDType, PtrDType)))
@@ -125,41 +125,41 @@ class CStyleLanguage(Renderer):
depth = 1
c: DefaultDict[str, int] = defaultdict(int)
for u in uops:
if u.op in (UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR):
r[u] = f"data{u.arg}" if u.op is UOps.DEFINE_GLOBAL else u.arg[0]
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else u.arg[0]
bufs[u] = (r[u], (u.dtype, False))
continue
# mark buffers that we store to as writable
if u.op is UOps.STORE:
if u.op is Ops.STORE:
for up in u.src[0].sparents:
if up.op is UOps.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
if up.op is Ops.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
# naming
prefix = None
if u.op is UOps.SPECIAL:
if u.op is Ops.SPECIAL:
r[u] = u.arg[0]
else:
prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const",
UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast", UOps.NOOP: "precast",
UOps.INDEX: "bidx", UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
prefix = {Ops.RANGE: "ridx", Ops.ALU: "alu", Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.VECTORIZE: "cast", Ops.NOOP: "precast",
Ops.INDEX: "bidx", Ops.DEFINE_ACC: "acc", Ops.LOAD: "val"}.get(u.op, "unk")
r[u] = f"{prefix}{c[prefix]}"
l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
if u.op in {UOps.CONST, UOps.GEP, UOps.INDEX} or (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST}
if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1
if u.op in {Ops.CONST, Ops.GEP, Ops.INDEX} or (u.op in {Ops.VECTORIZE, Ops.ALU, Ops.CAST, Ops.BITCAST}
and child_count[u] == 1 and not getenv("EXPAND_SSA")):
r[u] = l
else:
if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void:
if u.op is UOps.ASSIGN: r[u] = r[u.src[0]]
if u.op in {Ops.RANGE, Ops.ASSIGN, Ops.DEFINE_LOCAL} or u.dtype == dtypes.void:
if u.op is Ops.ASSIGN: r[u] = r[u.src[0]]
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not UOps.SPECIAL else "")
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {UOps.IF, UOps.RANGE}: depth += 1
if u.op in {Ops.IF, Ops.RANGE}: depth += 1
del self.r
# NOTE: this relies on bufs dict preserving order
@@ -189,7 +189,7 @@ class ClangRenderer(CStyleLanguage):
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = [self.render_vector_prefix(dt) for dt in uops_to_dtypes(uops) if dt.count > 1]
# https://github.com/corsix/amx
for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
for name, (N, M, _), dtype_in, _, _, _, _, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
prefix += [
'#define AMX_SET(imm5) __asm("nop\\nnop\\nnop\\n.word (0x201000+(%0<<5)+%1)" : : "i"(17), "i"(imm5) : "memory")',
'#define AMX(op, gpr, btf) __asm(".word (0x201000+(%0 << 5)+0%1-((0%1>>4)*6))" : : "i"(op), "r"((unsigned long long)(gpr)+(btf)) : "memory")',
@@ -214,13 +214,13 @@ class OpenCLRenderer(CStyleLanguage):
type_map = { dtypes.uint8: "uchar", dtypes.uint32: "uint", dtypes.uint16: "ushort", dtypes.uint64: "ulong", dtypes.bfloat16: "ushort" }
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_{ctx.render_dtype(x.dtype)}({ctx[x.src[0]]})"),
# load/store image (OpenCL)
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var"), UPat.var("gate"))),
lambda ctx,buf,idx,var,gate: f"({ctx[gate]}?read_imagef({ctx[buf]}, smp, {ctx[idx]}):{ctx[var]})"),
(UPat(UOps.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
(UPat(Ops.LOAD, dtype=dtypes.float.vec(4), src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))),)),
lambda ctx,buf,idx: f"read_imagef({ctx[buf]}, smp, {ctx[idx]})"),
(UPat(UOps.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
(UPat(Ops.STORE, src=(UPat.var('buf').index(UPat.var('idx', dtypes.int.vec(2))), UPat.var("var", dtypes.float.vec(4))), allow_any_len=True),
lambda ctx,buf,idx,var: f"write_imagef({ctx[buf]}, {ctx[idx]}, {ctx[var]});"),
]) + base_rewrite
@@ -234,13 +234,13 @@ class IntelRenderer(OpenCLRenderer):
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
string_rewrite = PatternMatcher([
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
(UPat(UOps.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
]) + OpenCLRenderer.string_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
prefix = []
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
dt_in = ("ushort", "bf16") if arg[2] == dtypes.bfloat16 else (arg[2].name, "f16")
prefix.append(f"""{arg[3].name}8 __{arg[0]}({dt_in[0]}16 a, {dt_in[0]}16 b, {arg[3].name}8 c) {{
return intel_sub_group_{dt_in[1]}_{dt_in[1]}_matrix_mad_k16(as_int8(a), as_int8(b), c);\n}}""")
@@ -272,17 +272,17 @@ class MetalRenderer(CStyleLanguage):
# upcast to float32 all the ops that don't support bfloat16
extra_matcher = PatternMatcher([
# NOTE: this is copied from PTX
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))
for op in [UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN]]
]) + extra_pm
string_rewrite = PatternMatcher([
(UPat(UOps.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"as_type<{ctx.render_dtype(x.dtype)}>({ctx[x.src[0]]})"),
]) + base_rewrite
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None):
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is UOps.WMMA])
prefix, wmma_args = ["#include <metal_stdlib>","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA])
for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{
simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x;
b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c);
@@ -335,7 +335,7 @@ class CUDARenderer(CStyleLanguage):
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count in (4,8) and dt.scalar() in {dtypes.half, dtypes.bfloat16}]
dt_map = { dtypes.half: "f16", dtypes.bfloat16: "bf16" }
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]):
for name, (N, M, K), dtype_in, dtype_out, _, _, upcast_axes, _ in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]):
upcast_sizes = [prod(size for _, size in upcast) for upcast in upcast_axes]
wmma_dtypes = [self.render_dtype(dtype.vec(size)) for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)]
n_operands = [size*dtype.itemsize//4 for dtype, size in zip([dtype_in, dtype_in, dtype_out], upcast_sizes)] # 4 => CUDA reg size in bytes
@@ -353,7 +353,7 @@ class CUDARenderer(CStyleLanguage):
return super().render_kernel(function_name, kernel, bufs, uops, prefix=prefix)
def get_kernel_modifier(self, uops:List[UOp]) -> str:
maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
maxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
# https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html
return f"__launch_bounds__({maxThreadsPerBlock}) "
@@ -387,20 +387,20 @@ class AMDRenderer(CStyleLanguage):
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
# cast bfloat16 alus to float
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
(UPat(Ops.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(Ops.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
(UPat(Ops.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)),
(UPat(UOps.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
(UPat(Ops.ALU, dtypes.bool, name="alu", src=(UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda alu,x,y: UOp(alu.op, dtypes.bool, (x.cast(dtypes.float), y.cast(dtypes.float)), alu.arg)),
# add float intermediate casting for bfloat16
(UPat(UOps.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(UOps.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, name="x", src=UPat.var("y", dtypes.bfloat16)),lambda x,y: y.cast(dtypes.float).cast(x.dtype) if x.dtype!=dtypes.float else None),
(UPat(Ops.CAST, dtypes.bfloat16, UPat.var("x")),lambda x: x.cast(dtypes.float).cast(dtypes.bfloat16) if x.dtype!=dtypes.float else None),
# bfloat16 casting
(UPat.cvar('x', dtypes.bfloat16), lambda x: cast_float_to_bf16(UOp.const(dtypes.float, x.arg))),
(UPat(UOps.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
(UPat(Ops.CAST, dtype=dtypes.float, src=UPat.var("x", dtype=dtypes.bfloat16)),
lambda x: (x.bitcast(dtypes.ushort).cast(dtypes.uint)<<16).bitcast(dtypes.float)),
(UPat(UOps.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=UPat.var("x", dtype=dtypes.float)), cast_float_to_bf16)]) + extra_pm
def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())
@@ -414,7 +414,7 @@ class AMDRenderer(CStyleLanguage):
if any(dt.scalar() == dtypes.bfloat16 for dt in used_dtypes): prefix.append("struct hip_bfloat16 { unsigned short data; };")
prefix += [self.render_vector_prefix(dt) for dt in used_dtypes if dt.count > 1]
for arg in dedup([uop.arg for uop in uops if uop.op is UOps.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
for arg in dedup([uop.arg for uop in uops if uop.op is Ops.WMMA]): # TODO: handle TCs f32_bf16 and bf16_bf16 w/ wrapper
if arg[3] == dtypes.float: prefix.append(f"#define __{arg[0]} __builtin_amdgcn_wmma_f32_16x16x16_f16_w32")
else: prefix.append(f"static inline __attribute__((device)) half8 __{arg[0]}"+"""(half16 a, half16 b, half8 c) {
half16 c_frag = {}; half8 d; for (int n = 0; n < 8; n++) { c_frag[n*2] = c[n]; }
@@ -423,7 +423,7 @@ class AMDRenderer(CStyleLanguage):
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
def get_kernel_modifier(self, uops:List[UOp]) -> str:
requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is UOps.SPECIAL and u.arg[0][0] == "l")
requiredMaxThreadsPerBlock = prod(u.arg[1] for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == "l")
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
# NOTE: this makes hlb_cifar10 twice as fast; there may be more gains in tweaking these parameters
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"

View File

@@ -1,7 +1,7 @@
from typing import Dict, Callable, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps, Ops, UOp
from tinygrad.renderer import Renderer
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
@@ -71,7 +71,7 @@ class LLVMRenderer(Renderer):
module = ir.Module(name=__file__)
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}}
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
@@ -90,14 +90,14 @@ class LLVMRenderer(Renderer):
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.INDEX:
if uop is Ops.INDEX:
lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
elif uop is UOps.STORE:
elif uop is Ops.STORE:
if len(src) > 2:
with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
else:
bb[-1].store(lvars[src[1]], lvars[src[0]])
elif uop is UOps.ENDRANGE:
elif uop is Ops.ENDRANGE:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
@@ -105,7 +105,7 @@ class LLVMRenderer(Renderer):
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
else:
if uop is UOps.RANGE:
if uop is Ops.RANGE:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
@@ -119,10 +119,10 @@ class LLVMRenderer(Renderer):
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
elif uop is Ops.DEFINE_ACC:
lvars[u] = const(src[0].arg, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
elif uop is Ops.LOAD:
if len(src) > 1:
with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
with then:
@@ -135,17 +135,17 @@ class LLVMRenderer(Renderer):
else:
val = bb[-1].load(lvars[src[0]])
lvars[u] = val
elif uop is UOps.ASSIGN:
elif uop is Ops.ASSIGN:
lvars[u] = lvars[src[1]]
# ASSIGN UOps can link to other ASSIGN UOps; backtrace this to DEFINE_ACC
backward = src[0]
while backward.op is UOps.ASSIGN: backward = backward.src[0]
while backward.op is Ops.ASSIGN: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop is UOps.ALU:
elif uop is Ops.ALU:
lvars[u] = self.code_for_op[args](bb[-1], *[lvars[x] for x in src], src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in {UOps.CAST, UOps.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is UOps.BITCAST)
elif uop in {UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is UOps.CONST: lvars[u] = const(args, dtype)
elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is Ops.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")
bb[-1].ret_void()
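
A standalone llvmlite sketch (not from the diff) of the RANGE/ENDRANGE wiring performed above: a phi node carries the loop index, the increment is added as a second incoming value, and a conditional branch closes the loop.

from llvmlite import ir

module = ir.Module(name="sketch")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), []), name="loop")
entry = ir.IRBuilder(func.append_basic_block("entry"))
body_block, exit_block = func.append_basic_block("loop_body_0"), func.append_basic_block("loop_exit_0")
entry.branch(body_block)
body = ir.IRBuilder(body_block)
idx = body.phi(ir.IntType(32), name="loop0")               # the RANGE value
idx.add_incoming(ir.Constant(ir.IntType(32), 0), entry.block)
idx_p1 = body.add(idx, ir.Constant(ir.IntType(32), 1))     # ENDRANGE: increment
idx.add_incoming(idx_p1, body.block)
body.cbranch(body.icmp_unsigned("<", idx_p1, ir.Constant(ir.IntType(32), 10)), body_block, exit_block)
ir.IRBuilder(exit_block).ret_void()
print(module)  # textual LLVM IR for a 10-iteration loop
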

View File

@@ -1,7 +1,7 @@
from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable
import struct
from collections import defaultdict
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, UOps, UOp, PatternMatcher, UPat
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op, Ops, UOp, PatternMatcher, UPat
from tinygrad.dtype import dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer
@@ -38,20 +38,20 @@ ptx_matcher = PatternMatcher([
(UPat.var('x', dtype=dtypes.bool).ne(UPat.var('y')), lambda x,y: x^y),
(UPat.var('x', dtype=dtypes.bool).lt(UPat.var('y')), lambda x,y: (x^True)&y),
# upcast to float32 all the ops that don't support half
*[(UPat(UOps.ALU, arg=op, dtype=dtypes.half, name="x"),
*[(UPat(Ops.ALU, arg=op, dtype=dtypes.half, name="x"),
lambda x: (UOp(x.op, dtypes.float32, tuple(vv.cast(dtypes.float32) for vv in x.src), x.arg).cast(dtypes.half)))
for op in asm_for_op.keys() if op not in supports_half],
# load/store bool -> uint8
(UPat(UOps.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
(UPat(Ops.LOAD, dtypes.bool, src=(UPat(dtype=dtypes.int64),), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.uint8, x.src[0:1] + ((x.src[1].cast(dtypes.uint8),) if len(x.src) >= 2 else ()) + x.src[2:]).cast(dtypes.bool)),
(UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
(UPat(Ops.STORE, src=(UPat(dtype=dtypes.int64), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: UOp(x.op, dtypes.void, x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(UOps.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(UOps.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))), lambda buf,idx: buf.cast(dtypes.int64) + idx.cast(dtypes.int64)*buf.dtype.itemsize),
(UPat(Ops.CAST, name="x"), lambda x: x.src[0] if isinstance(x.dtype, PtrDType) else None),
# ptx shr and shl instructions require y to be uint
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(UOps.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(UOps.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
(UPat.var("x") << UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
(UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.ALU, x.dtype, (x,y.cast(dtypes.uint)), BinaryOps.SHR) if y.dtype != dtypes.uint else None),
])
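
A standalone sketch of the SHL rule above (assumption: a one-rule PatternMatcher, not the full ptx_matcher): when the shift amount is not already uint it gets cast, since PTX shl/shr want an unsigned shift operand.

from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, BinaryOps, graph_rewrite
from tinygrad.dtype import dtypes

shl_fix = PatternMatcher([
  (UPat.var("x") << UPat.var("y"),
   lambda x,y: UOp(Ops.ALU, x.dtype, (x, y.cast(dtypes.uint)), BinaryOps.SHL) if y.dtype != dtypes.uint else None),
])
x = UOp.variable("x", 0, 255, dtypes.int)
y = UOp.variable("y", 0, 7, dtypes.int)
print(graph_rewrite(x << y, shl_fix).src[1].dtype)  # dtypes.uint after the rewrite
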
class PTXRenderer(Renderer):
@@ -146,51 +146,51 @@ class PTXRenderer(Renderer):
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.IF:
if uop is Ops.IF:
pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True)
kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True))
elif uop is UOps.BARRIER and self.barrier: kk(self.barrier)
elif uop is UOps.ENDRANGE:
elif uop is Ops.BARRIER and self.barrier: kk(self.barrier)
elif uop is Ops.ENDRANGE:
kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]),
self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int]))
kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred))
elif uop is UOps.ENDIF:
elif uop is Ops.ENDIF:
kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:")
elif uop is UOps.STORE:
elif uop is Ops.STORE:
assert src[0].dtype == dtypes.int64, "store isn't int64"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not UOps.IF else ""
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else ""
if src[1].dtype.count > 1:
kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};")
else:
kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};")
else:
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is UOps.ALU:
if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
elif uop is Ops.ALU:
src_dtype = src[0].dtype if args in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype
kk(self.code_for_op[args](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype]))
elif uop is UOps.DEFINE_ACC:
elif uop is Ops.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
elif uop is UOps.SPECIAL:
elif uop is Ops.SPECIAL:
assert args[0][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};")
r[u] = "%" + args[0]
kernel = [f".reg .u32 %{args[0]};"] + kernel
elif uop is UOps.DEFINE_VAR:
elif uop is Ops.DEFINE_VAR:
bufs.append((args[0], dtype))
r[u] = f"%{args[0]}"
kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP:
elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True)
elif uop is Ops.GEP:
assert len(u.arg) == 1
r[u] = r[src[0]][u.arg[0]]
elif uop is UOps.LOAD:
elif uop is Ops.LOAD:
assert src[0].dtype == dtypes.int64, "load isn't int64"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op is UOps.ALU
mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global'
has_gate = len(src) > 2 and src[2].op is Ops.ALU
if dtype.count > 1:
r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
if has_gate:
@@ -200,25 +200,25 @@ class PTXRenderer(Renderer):
else:
kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None,
alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0))
elif uop is UOps.ASSIGN:
elif uop is Ops.ASSIGN:
if dtype.count > 1:
for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};")
else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};")
r[u] = r[src[0]]
# NOTE: casting to str is fine because you can't vectorize a vectorize
elif uop is UOps.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
elif uop in {UOps.CAST, UOps.BITCAST}:
_cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src]
elif uop in {Ops.CAST, Ops.BITCAST}:
_cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u)
elif uop is Ops.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is UOps.DEFINE_GLOBAL:
elif uop is Ops.DEFINE_GLOBAL:
bufs.append((nm:=f"data{args}", dtype))
r[u] = f"%{nm}"
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param"))
elif uop is UOps.WMMA:
elif uop is Ops.WMMA:
_, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args
wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2])
dt_map = { dtypes.half: "f16" }

View File

@@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct
from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, UOps, UOp
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer
@@ -26,7 +26,7 @@ def _store(m, i, v):
class PythonProgram:
def __init__(self, name:str, lib:bytes):
self.uops: List[Tuple[UOps, Optional[DType], List[int], Any]] = pickle.loads(lib)
self.uops: List[Tuple[Ops, Optional[DType], List[int], Any]] = pickle.loads(lib)
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
st = time.perf_counter()
warp = list(itertools.product(*[range(x) for x in local_size[::-1]]))
@@ -40,12 +40,12 @@ class PythonProgram:
loop_ends: Dict[int, int] = {}
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
if uop is UOps.DEFINE_ACC: idp = [idp[0]]
void_ops = {Ops.STORE, Ops.ENDRANGE, Ops.BARRIER, Ops.IF, Ops.ENDIF}
if uop is Ops.DEFINE_ACC: idp = [idp[0]]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
if uop is UOps.STORE:
if uop is Ops.STORE:
if len(inp) == 2: inp.append([True] * len(inp[0])) # set the gate to True
if dtp[1].count > 1:
for j,val in enumerate(inp[1]):
@@ -56,32 +56,32 @@ class PythonProgram:
if g: _store(m, o, v)
i += 1
continue
if uop is UOps.ENDRANGE:
if uop is Ops.ENDRANGE:
loop_ends[idp[0]] = i
i = idp[0]
continue
if uop in (UOps.BARRIER, UOps.IF, UOps.ENDIF):
if uop in (Ops.BARRIER, Ops.IF, Ops.ENDIF):
# in the python emulator, the warp is always in sync
i += 1
continue
assert dtype is not None, f"{uop} is missing a dtype"
dl[i] = dtype
if uop is UOps.DEFINE_GLOBAL:
if uop is Ops.DEFINE_GLOBAL:
assert dtype.fmt is not None
ul[i] = [pbufs.pop(0).cast(dtype.fmt)] * warp_size
elif uop is UOps.DEFINE_LOCAL:
elif uop is Ops.DEFINE_LOCAL:
assert dtype.fmt is not None
lbuf = memoryview(bytearray(arg[1]*dtype.itemsize))
ul[i] = [lbuf.cast(dtype.fmt)] * warp_size
elif uop is UOps.DEFINE_VAR:
elif uop is Ops.DEFINE_VAR:
ul[i] = [pvals.pop(0)] * warp_size
elif uop is UOps.SPECIAL:
elif uop is Ops.SPECIAL:
if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
elif uop is UOps.CONST: ul[i] = [arg] * warp_size
elif uop is UOps.DEFINE_ACC:
elif uop is Ops.CONST: ul[i] = [arg] * warp_size
elif uop is Ops.DEFINE_ACC:
ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
elif uop is UOps.INDEX:
elif uop is Ops.INDEX:
ret = []
if isinstance(dtp[0], ImageDType):
for m,ox,oy in zip(inp[0], inp[1][0], inp[1][1]):
@@ -90,9 +90,9 @@ class PythonProgram:
else:
for m,o in zip(inp[0], inp[1]): ret.append((m,o))
ul[i] = ret
elif uop is UOps.CAST and isinstance(dtype, PtrDType):
elif uop is Ops.CAST and isinstance(dtype, PtrDType):
ul[i] = inp[0]
elif uop is UOps.RANGE:
elif uop is Ops.RANGE:
if i not in ul: ul[i] = [inp[0][0]] * warp_size
else:
for j in range(len(ul[i])):
@@ -101,11 +101,11 @@ class PythonProgram:
del ul[i]
i = loop_ends[i] + 1
continue
elif uop is UOps.VECTORIZE: ul[i] = inp
elif uop in {UOps.CAST, UOps.BITCAST}:
elif uop is Ops.VECTORIZE: ul[i] = inp
elif uop in {Ops.CAST, Ops.BITCAST}:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
if uop is Ops.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
if dtypes.is_int(dtype):
@@ -114,18 +114,18 @@ class PythonProgram:
elif dtypes.is_float(dtype):
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
elif uop is UOps.LOAD:
elif uop is Ops.LOAD:
if dtype.count > 1:
ul[i] = [load([inp[i][j] if i != 0 and dtp[i].count > 1 else inp[i] for i in range(len(inp))], j) for j in range(dtype.count)]
else:
ul[i] = load(inp)
elif uop is UOps.ASSIGN:
elif uop is Ops.ASSIGN:
for j in range(len(inp[0])): inp[0][j] = inp[1][j]
ul[i] = inp[0]
elif uop is UOps.GEP:
elif uop is Ops.GEP:
assert len(arg) == 1
ul[i] = inp[0][arg[0]]
elif uop is UOps.WMMA:
elif uop is Ops.WMMA:
# here are the models for the WMMA instruction on the different hardware
def wmma_helper(WARP_THREADS, K, NUM_A, NUM_B, NUM_C, a_elem, b_elem, c_map):
assert len(inp[0]) == NUM_A, f"A must have {NUM_A} elements per thread, it has {len(inp[0])}"
@@ -180,7 +180,7 @@ class PythonProgram:
def c_map(_, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif uop is UOps.ALU:
elif uop is Ops.ALU:
assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {arg}"
assert all_same([dtype] + dtp) or arg in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {arg}"
ul[i] = [exec_alu(arg, dtype, p) for p in zip(*inp)]
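
The ALU branch above just maps exec_alu across the warp lanes; a standalone sketch of that zip pattern, assuming exec_alu takes (op, dtype, operand-tuple) as it is called here:

from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, exec_alu

inp = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two 4-lane "warp" inputs, as the emulator stores them
out = [exec_alu(BinaryOps.ADD, dtypes.int32, p) for p in zip(*inp)]
print(out)  # expected [11, 22, 33, 44]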

@@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint
from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint
@dataclass(frozen=True, order=True)
class ShapeTracker:
@@ -40,7 +40,7 @@ class ShapeTracker:
def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self)
def to_uop(self) -> UOp: return UOp(Ops.VIEW, dtypes.void, (), self)
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
idx, valid = self.views[-1].to_indexed_uops(_idxs)
@@ -75,20 +75,20 @@ class ShapeTracker:
ret: List[Optional[sint]] = [None] * len(self.shape)
idx, valid = (graph_rewrite(u, symbolic_flat) for u in self.to_indexed_uops())
for c in split_uop(idx, BinaryOps.ADD):
if c.op is UOps.RANGE: ret[c.arg] = 1
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[0].op is UOps.RANGE and c.src[1].op is UOps.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is UOps.ALU and c.arg is BinaryOps.MUL and c.src[1].op is UOps.RANGE and c.src[0].op is UOps.CONST: ret[c.src[1].arg] = c.src[0].arg
used_ranges = [x.arg for x in idx.sparents if x.op is UOps.RANGE]
if c.op is Ops.RANGE: ret[c.arg] = 1
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg] = c.src[1].arg
if c.op is Ops.ALU and c.arg is BinaryOps.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg] = c.src[0].arg
used_ranges = [x.arg for x in idx.sparents if x.op is Ops.RANGE]
ret = [x if i in used_ranges else 0 for i,x in enumerate(ret)]
if not ignore_valid:
for masked_axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE]: ret[masked_axis] = None
for masked_axis in [x.arg for x in valid.sparents if x.op is Ops.RANGE]: ret[masked_axis] = None
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is UOps.RANGE]
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).sparents if x.op is Ops.RANGE]
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
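
A small sketch of the methods touched above, assuming a plain contiguous tracker (from_shape is used the same way elsewhere in this commit):

from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import Ops

st = ShapeTracker.from_shape((4, 4))
assert st.to_uop().op is Ops.VIEW   # the VIEW uop carries the tracker as its arg
print(st.real_strides())            # expected (4, 1) for a contiguous 4x4 view
print(st.unit_stride_axes())        # expected [1]
print(st.axis_is_masked(0))         # expected False, nothing is masked here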

@@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, UOps, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule
@@ -136,7 +136,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data)
elif isinstance(data, UOp):
assert data.op is UOps.BIND and data.src[0].op is UOps.DEFINE_VAR and data.src[1].op is UOps.CONST, f"can't create tensor from UOp {data}"
assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}"
data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype)
elif isinstance(data, (list, tuple)):
@@ -382,9 +382,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
if y.op is UOps.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is UOps.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is UOps.ALU:
if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor
if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False)
if y.op is Ops.ALU:
if y.arg is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1])
if y.arg is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1]))
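
A sketch of what the from_uop path above accepts, assuming the usual Variable helper (which yields exactly the BIND-of-DEFINE_VAR/CONST shape the earlier assert demands):

from tinygrad import Tensor
from tinygrad.ops import Ops, Variable

v = Variable("i", 1, 10).bind(3)
assert v.op is Ops.BIND and v.src[0].op is Ops.DEFINE_VAR and v.src[1].op is Ops.CONST
t = Tensor.from_uop(v * 2)  # the ALU/MUL case recurses into from_uop on each source
print(t.item())             # expected 6, assuming variable-backed consts realize normally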

@@ -5,13 +5,13 @@ from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Tuple, Optional
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedRewriteContext, UOp, UOps, lines
from tinygrad.ops import TrackedRewriteContext, UOp, Ops, lines
from tinygrad.codegen.kernel import Kernel
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
UOps.INDEX: "#e8ffa0", UOps.WMMA: "#efefc0", UOps.VIEW: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
uops_colors = {Ops.ALU: "#ffffc0", Ops.LOAD: "#ffc0c0", Ops.STORE: "#c0ffc0", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE: "#C4A484",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.REDUCE_AXIS: "#f58488"}
# ** API spec
@@ -50,7 +50,7 @@ def get_metadata(contexts:List[Tuple[Any, List[TrackedRewriteContext]]]) -> List
for k,ctxs in contexts:
name = to_function_name(k.name) if isinstance(k, Kernel) else k
for ctx in ctxs:
if ctx.sink.op is UOps.CONST: continue
if ctx.sink.op is Ops.CONST: continue
upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
if name not in kernels: kernels[name] = []
kernels[name].append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
@@ -60,11 +60,11 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
for u in x.sparents:
if u.op is UOps.CONST: continue
if u.op is Ops.CONST: continue
label = f"{str(u.op)[5:]}{(' '+word_wrap(str(u.arg).replace(':', ''))) if u.arg is not None else ''}\n{str(u.dtype)}"
for idx,x in enumerate(u.src):
if x.op is UOps.CONST: label += f"\nCONST{idx} {x.arg:g}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
if x.op is Ops.CONST: label += f"\nCONST{idx} {x.arg:g}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not Ops.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph
def _replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
if (found:=replaces.get(base)) is not None: return found
@@ -88,7 +88,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)
if new_sink is sink:
raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not UOps.CONST])
g.changed_nodes.append([id(x) for x in u1.sparents if x.op is not Ops.CONST])
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
g.graphs.append(sink:=new_sink)
return g
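
A standalone sketch of the CONST filtering uop_to_json applies above, over a toy symbolic expression (sparents is the same parent-walk used in the ShapeTracker hunks):

from tinygrad.ops import Ops, Variable

expr = Variable("i", 1, 10) * 2 + 1
kept = [u.op for u in expr.sparents if u.op is not Ops.CONST]
print(kept)  # expected: the DEFINE_VAR plus the two ALU nodes, with the CONSTs dropped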